188 lines
6.8 KiB
Python
Raw Normal View History

2023-02-03 00:05:28 -08:00
"""Qdrant reader."""
2023-05-04 23:06:58 -07:00
from typing import List, Optional, cast, Dict
2023-02-03 00:05:28 -08:00
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
2023-02-03 00:05:28 -08:00
class QdrantReader(BaseReader):
    """Qdrant reader.

    Retrieve documents from existing Qdrant collections.

    All constructor arguments are forwarded unchanged to
    `qdrant_client.QdrantClient`.

    Args:
        location:
            If `:memory:` - use in-memory Qdrant instance.
            If `str` - use it as a `url` parameter.
            If `None` - use default values for `host` and `port`.
        url:
            either host or str of
            "Optional[scheme], host, Optional[port], Optional[prefix]".
            Default: `None`
        port: Port of the REST API interface. Default: 6333
        grpc_port: Port of the gRPC interface. Default: 6334
        prefer_grpc: If `true` - use gRPC interface whenever possible in
            custom methods. Default: `False`
        https: If `true` - use HTTPS(SSL) protocol. Default: `None`
        api_key: API key for authentication in Qdrant Cloud. Default: `None`
        prefix:
            If not `None` - add `prefix` to the REST URL path.
            Example: `service/v1` will result in
            `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API.
            Default: `None`
        timeout:
            Timeout for REST and gRPC API requests.
            Default: `None` (the client then applies its own defaults,
            documented upstream as 5.0 seconds for REST and unlimited
            for gRPC).
        host: Host name of Qdrant service. If url and host are None, set to
            'localhost'. Default: `None`
        path: Forwarded to `QdrantClient(path=...)`; presumably the
            on-disk location for a local Qdrant instance - confirm against
            the qdrant-client documentation. Default: `None`

    """

    def __init__(
        self,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
    ):
        """Initialize with parameters.

        Raises:
            ImportError: If the `qdrant-client` package is not installed.

        """
        import_err_msg = (
            "`qdrant-client` package not found, please run `pip install qdrant-client`"
        )
        # Imported lazily so that qdrant-client stays an optional dependency.
        try:
            import qdrant_client
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(import_err_msg) from err

        self._client = qdrant_client.QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
        )

    def load_data(
        self,
        collection_name: str,
        query_vector: List[float],
        should_search_mapping: Optional[Dict[str, str]] = None,
        must_search_mapping: Optional[Dict[str, str]] = None,
        must_not_search_mapping: Optional[Dict[str, str]] = None,
        rang_search_mapping: Optional[Dict[str, Dict[str, float]]] = None,
        limit: int = 10,
    ) -> List[Document]:
        """Load data from Qdrant.

        Runs a vector search against `collection_name` and converts every
        returned point into a `Document`.

        Args:
            collection_name (str): Name of the Qdrant collection.
            query_vector (List[float]): Query vector.
            should_search_mapping (Optional[Dict[str, str]]): Mapping from field
                name to query string. Each entry becomes a `MatchText`
                condition in the filter's `should` clause.
            must_search_mapping (Optional[Dict[str, str]]): Mapping from field
                name to query string. Each entry becomes a `MatchValue`
                condition in the filter's `must` clause.
            must_not_search_mapping (Optional[Dict[str, str]]): Mapping from
                field name to query string. Each entry becomes a `MatchValue`
                condition in the filter's `must_not` clause.
            rang_search_mapping (Optional[Dict[str, Dict[str, float]]]): Mapping
                from field name to range query. Supported bound keys are
                `gte`, `lte`, `gt` and `lt`; missing bounds are left unset.
                Range conditions are added to the `should` clause.
            limit (int): Number of results to return.

        Example:
            reader = QdrantReader()
            reader.load_data(
                collection_name="test_collection",
                query_vector=[0.1, 0.2, 0.3],
                should_search_mapping={"text_field": "text"},
                must_search_mapping={"text_field": "text"},
                must_not_search_mapping={"text_field": "text"},
                # gte, lte, gt, lt supported
                rang_search_mapping={"text_field": {"gte": 0.1, "lte": 0.2}},
                limit=10
            )

        Returns:
            List[Document]: A list of documents. `doc_id`, `text` and
                `extra_info` are read from the point payload (each is `None`
                when the key is absent), and `embedding` is the point vector.

        """
        from qdrant_client.http.models import (
            FieldCondition,
            Filter,
            MatchText,
            MatchValue,
            Range,
        )
        from qdrant_client.http.models.models import Payload

        # Normalise omitted mappings to empty dicts so each comprehension
        # below simply yields no conditions for them.
        should_search_mapping = should_search_mapping or {}
        must_search_mapping = must_search_mapping or {}
        must_not_search_mapping = must_not_search_mapping or {}
        rang_search_mapping = rang_search_mapping or {}

        should_search_conditions = [
            FieldCondition(key=key, match=MatchText(text=value))
            for key, value in should_search_mapping.items()
        ]
        must_search_conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in must_search_mapping.items()
        ]
        must_not_search_conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in must_not_search_mapping.items()
        ]
        rang_search_conditions = [
            FieldCondition(
                key=key,
                range=Range(
                    gte=bounds.get("gte"),
                    lte=bounds.get("lte"),
                    gt=bounds.get("gt"),
                    lt=bounds.get("lt"),
                ),
            )
            for key, bounds in rang_search_mapping.items()
        ]
        # Range conditions share the `should` clause with the text matches.
        should_search_conditions.extend(rang_search_conditions)

        response = self._client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            query_filter=Filter(
                must=must_search_conditions,
                must_not=must_not_search_conditions,
                should=should_search_conditions,
            ),
            with_vectors=True,
            with_payload=True,
            limit=limit,
        )

        documents = []
        for point in response:
            # A point may carry no payload; fall back to an empty mapping so
            # the `.get` lookups below return None instead of raising.
            payload = cast(Payload, point.payload or {})
            # `cast` is a static-typing hint only: it performs no runtime
            # conversion and cannot raise, so no error handling is needed.
            vector = cast(List[float], point.vector)
            documents.append(
                Document(
                    doc_id=payload.get("doc_id"),
                    text=payload.get("text"),
                    extra_info=payload.get("extra_info"),
                    embedding=vector,
                )
            )

        return documents