2023-02-03 00:05:28 -08:00
|
|
|
"""Qdrant reader."""
|
|
|
|
|
2023-05-04 23:06:58 -07:00
|
|
|
from typing import List, Optional, cast, Dict
|
2023-02-03 00:05:28 -08:00
|
|
|
|
2023-02-20 21:46:58 -08:00
|
|
|
from llama_index.readers.base import BaseReader
|
|
|
|
from llama_index.readers.schema.base import Document
|
2023-02-03 00:05:28 -08:00
|
|
|
|
|
|
|
|
|
|
|
class QdrantReader(BaseReader):
    """Qdrant reader.

    Retrieve documents from existing Qdrant collections.

    All constructor arguments are forwarded unchanged to
    `qdrant_client.QdrantClient`.

    Args:
        location:
            If `:memory:` - use in-memory Qdrant instance.
            If `str` - use it as a `url` parameter.
            If `None` - use default values for `host` and `port`.
        url:
            either host or str of
            "Optional[scheme], host, Optional[port], Optional[prefix]".
            Default: `None`
        port: Port of the REST API interface. Default: 6333
        grpc_port: Port of the gRPC interface. Default: 6334
        prefer_grpc: If `true` - use gRPC interface whenever possible in custom
            methods. Default: `false`
        https: If `true` - use HTTPS(SSL) protocol. Default: `None`
        api_key: API key for authentication in Qdrant Cloud. Default: `None`
        prefix:
            If not `None` - add `prefix` to the REST URL path.
            Example: `service/v1` will result in
            `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API.
            Default: `None`
        timeout:
            Timeout for REST and gRPC API requests.
            Default: 5.0 seconds for REST and unlimited for gRPC
        host: Host name of Qdrant service. If url and host are None, set to
            'localhost'.
            Default: `None`
        path: Passed through to `QdrantClient(path=...)`. Per the qdrant-client
            documentation this is the persistence path for local mode
            (confirm against the installed client version).
            Default: `None`
    """

    def __init__(
        self,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
    ) -> None:
        """Initialize with parameters.

        Raises:
            ImportError: If the `qdrant-client` package is not installed.
        """
        import_err_msg = (
            "`qdrant-client` package not found, please run `pip install qdrant-client`"
        )
        # Imported lazily so the module can be loaded without qdrant-client.
        try:
            import qdrant_client
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(import_err_msg) from err

        self._client = qdrant_client.QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
        )

    def load_data(
        self,
        collection_name: str,
        query_vector: List[float],
        should_search_mapping: Optional[Dict[str, str]] = None,
        must_search_mapping: Optional[Dict[str, str]] = None,
        must_not_search_mapping: Optional[Dict[str, str]] = None,
        rang_search_mapping: Optional[Dict[str, Dict[str, float]]] = None,
        limit: int = 10,
    ) -> List[Document]:
        """Load data from Qdrant.

        Runs a vector search against `collection_name` and converts every
        returned point into a `Document`, reading the `doc_id`, `text` and
        `extra_info` keys from the point payload and using the point vector
        as the embedding.

        Args:
            collection_name (str): Name of the Qdrant collection.
            query_vector (List[float]): Query vector.
            should_search_mapping (Optional[Dict[str, str]]): Mapping from field name
                to query string. Each entry becomes a `MatchText` condition in
                the filter's `should` clause.
            must_search_mapping (Optional[Dict[str, str]]): Mapping from field name
                to query string. Each entry becomes a `MatchValue` condition in
                the filter's `must` clause.
            must_not_search_mapping (Optional[Dict[str, str]]): Mapping from field
                name to query string. Each entry becomes a `MatchValue`
                condition in the filter's `must_not` clause.
            rang_search_mapping (Optional[Dict[str, Dict[str, float]]]): Mapping from
                field name to range query. Supported bounds are `gte`, `lte`,
                `gt` and `lt`; missing bounds are left unset. Range conditions
                are appended to the `should` clause.
            limit (int): Number of results to return.

        Example:
            reader = QdrantReader()
            reader.load_data(
                collection_name="test_collection",
                query_vector=[0.1, 0.2, 0.3],
                should_search_mapping={"text_field": "text"},
                must_search_mapping={"text_field": "text"},
                must_not_search_mapping={"text_field": "text"},
                # gte, lte, gt, lt supported
                rang_search_mapping={"text_field": {"gte": 0.1, "lte": 0.2}},
                limit=10
            )

        Returns:
            List[Document]: A list of documents.
        """
        from qdrant_client.http.models import (
            FieldCondition,
            Filter,
            MatchText,
            MatchValue,
            Range,
        )
        from qdrant_client.http.models.models import Payload

        # Normalise omitted mappings so each comprehension below can iterate
        # unconditionally (an empty mapping simply yields no conditions).
        should_search_mapping = should_search_mapping or {}
        must_search_mapping = must_search_mapping or {}
        must_not_search_mapping = must_not_search_mapping or {}
        rang_search_mapping = rang_search_mapping or {}

        should_search_conditions = [
            FieldCondition(key=key, match=MatchText(text=value))
            for key, value in should_search_mapping.items()
        ]
        must_search_conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in must_search_mapping.items()
        ]
        must_not_search_conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in must_not_search_mapping.items()
        ]
        rang_search_conditions = [
            FieldCondition(
                key=key,
                range=Range(
                    gte=value.get("gte"),
                    lte=value.get("lte"),
                    gt=value.get("gt"),
                    lt=value.get("lt"),
                ),
            )
            for key, value in rang_search_mapping.items()
        ]
        # Range conditions share the `should` clause with the text conditions.
        should_search_conditions.extend(rang_search_conditions)

        response = self._client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            query_filter=Filter(
                must=must_search_conditions,
                must_not=must_not_search_conditions,
                should=should_search_conditions,
            ),
            with_vectors=True,
            with_payload=True,
            limit=limit,
        )

        documents: List[Document] = []
        for point in response:
            # A point may carry no payload; fall back to an empty mapping so
            # the `.get` lookups below do not fail on `None`.
            payload = cast(Payload, point.payload or {})
            # `typing.cast` is a static-typing hint only: it returns its
            # argument unchanged and never raises, so no error handling is
            # needed (or possible) around it.
            vector = cast(List[float], point.vector)
            documents.append(
                Document(
                    doc_id=payload.get("doc_id"),
                    text=payload.get("text"),
                    extra_info=payload.get("extra_info"),
                    embedding=vector,
                )
            )

        return documents
|