| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | from __future__ import annotations | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from abc import ABC, abstractmethod | 
					
						
							|  |  |  | from typing import Any, List | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from langchain.schema import BaseRetriever, Document | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  | from models.dataset import Dataset | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseIndex(ABC): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, dataset: Dataset): | 
					
						
							|  |  |  |         self.dataset = dataset | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def create(self, texts: list[Document], **kwargs) -> BaseIndex: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-18 18:15:41 +08:00
										 |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def add_texts(self, texts: list[Document], **kwargs): | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def text_exists(self, id: str) -> bool: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def delete_by_ids(self, ids: list[str]) -> None: | 
					
						
							| 
									
										
										
										
											2023-12-18 13:10:05 +08:00
										 |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def delete_by_metadata_field(self, key: str, value: str) -> None: | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-18 18:15:41 +08:00
										 |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def delete_by_group_id(self, group_id: str) -> None: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-25 16:49:14 +08:00
										 |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def delete_by_document_id(self, document_id: str): | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def get_retriever(self, **kwargs: Any) -> BaseRetriever: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def search( | 
					
						
							|  |  |  |             self, query: str, | 
					
						
							|  |  |  |             **kwargs: Any | 
					
						
							|  |  |  |     ) -> List[Document]: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete(self) -> None: | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | 
					
						
							|  |  |  |         for text in texts: | 
					
						
							|  |  |  |             doc_id = text.metadata['doc_id'] | 
					
						
							|  |  |  |             exists_duplicate_node = self.text_exists(doc_id) | 
					
						
							|  |  |  |             if exists_duplicate_node: | 
					
						
							|  |  |  |                 texts.remove(text) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return texts | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _get_uuids(self, texts: list[Document]) -> list[str]: | 
					
						
							|  |  |  |         return [text.metadata['doc_id'] for text in texts] |