mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2025-06-26 23:19:56 +00:00
[AUR-421] base output post-processor that works using regex. (#20)
This commit is contained in:
parent
2a3a23ecd7
commit
b794051653
@ -47,3 +47,4 @@ repos:
|
||||
rev: "v1.5.1"
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: ["--check-untyped-defs", "--ignore-missing-imports"]
|
||||
|
@ -15,7 +15,7 @@ class Document(BaseDocument):
|
||||
)
|
||||
return document
|
||||
|
||||
def to_haystack_format(self) -> HaystackDocument:
|
||||
def to_haystack_format(self) -> "HaystackDocument":
|
||||
"""Convert struct to Haystack document format."""
|
||||
metadata = self.metadata or {}
|
||||
text = self.text
|
||||
|
@ -4,7 +4,7 @@ from typing import List, Type
|
||||
from langchain.schema.embeddings import Embeddings as LCEmbeddings
|
||||
from theflow import Param
|
||||
|
||||
from ..components import BaseComponent
|
||||
from ..base import BaseComponent
|
||||
from ..documents.base import Document
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from ..components import BaseComponent
|
||||
from ..base import BaseComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -4,7 +4,7 @@ from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from theflow.base import Param
|
||||
|
||||
from ...components import BaseComponent
|
||||
from ...base import BaseComponent
|
||||
from ..base import LLMInterface
|
||||
|
||||
Message = TypeVar("Message", bound=BaseMessage)
|
||||
|
@ -3,7 +3,7 @@ from typing import List, Type
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from theflow.base import Param
|
||||
|
||||
from ...components import BaseComponent
|
||||
from ...base import BaseComponent
|
||||
from ..base import LLMInterface
|
||||
|
||||
|
||||
|
@ -2,7 +2,7 @@ from typing import List
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
from ..components import BaseComponent
|
||||
from ..base import BaseComponent
|
||||
from ..documents.base import Document
|
||||
from ..embeddings import BaseEmbeddings
|
||||
from ..vectorstores import BaseVectorStore
|
||||
|
@ -2,7 +2,7 @@ from typing import List
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
from ..components import BaseComponent
|
||||
from ..base import BaseComponent
|
||||
from ..documents.base import Document
|
||||
from ..embeddings import BaseEmbeddings
|
||||
from ..vectorstores import BaseVectorStore
|
||||
|
0
knowledgehub/post_processing/__init__.py
Normal file
0
knowledgehub/post_processing/__init__.py
Normal file
166
knowledgehub/post_processing/extractor.py
Normal file
166
knowledgehub/post_processing/extractor.py
Normal file
@ -0,0 +1,166 @@
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.documents.base import Document
|
||||
|
||||
|
||||
class RegexExtractor(BaseComponent):
    """Simple class for extracting text from a document using a regex pattern.

    Args:
        pattern (str): The regex pattern to use.
        output_map (dict, optional): A mapping from extracted text to the
            desired output. Extracted text that is not a key of the mapping is
            returned unchanged. Defaults to an empty mapping.
    """

    pattern: str
    # Class-level default; it is only ever read (via `.get`) in this class.
    output_map: Dict[str, str] = {}

    @staticmethod
    def run_raw_static(pattern: str, text: str) -> List[str]:
        """
        Finds all non-overlapping occurrences of a pattern in a string.

        Parameters:
            pattern (str): The regular expression pattern to search for.
            text (str): The input string to search in.

        Returns:
            List[str]: A list of all non-overlapping occurrences of the pattern in the
                string, exactly as produced by `re.findall`. Note that when the
                pattern contains capture groups, `re.findall` returns the group
                contents (tuples if there are several groups) instead of the whole
                match.
        """
        return re.findall(pattern, text)

    @staticmethod
    def map_output(text, output_map) -> str:
        """
        Maps the given `text` to its corresponding value in the `output_map` dictionary.

        Parameters:
            text (str): The input text to be mapped.
            output_map (dict): A dictionary containing mapping of input text to output
                values. A falsy value (empty dict or None) disables the mapping.

        Returns:
            str: The corresponding value from the `output_map` if `text` is found in the
                dictionary, otherwise returns the original `text`.
        """
        if not output_map:
            return text

        return output_map.get(text, text)

    def run_raw(self, text: str) -> List[str]:
        """
        Runs the raw text through the static pattern and output mapping, returning a
        list of strings.

        Args:
            text (str): The raw text to be processed.

        Returns:
            List[str]: The processed output as a list of strings.
        """
        matches = self.run_raw_static(self.pattern, text)

        return [self.map_output(match, self.output_map) for match in matches]

    def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]:
        """
        Runs a batch of raw text inputs through the `run_raw()` method and returns the
        output for each input.

        Parameters:
            text_batch (List[str]): A list of raw text inputs to process.

        Returns:
            List[List[str]]: A list of lists containing the output for each input in the
                batch.
        """
        return [self.run_raw(each_text) for each_text in text_batch]

    def run_document(self, document: Document) -> List[Document]:
        """
        Run the document through the regex extractor and return a list of extracted
        documents.

        Each extracted document carries a copy of the input document's metadata
        plus a `"RegexExtractor": True` marker.

        Args:
            document (Document): The input document.

        Returns:
            List[Document]: A list of extracted documents.
        """
        # Guard against a missing (None) metadata mapping: unpacking None with
        # `**` would raise a TypeError.
        metadata = {**(document.metadata or {}), "RegexExtractor": True}

        # Every output document gets its own metadata dict so that mutating one
        # result does not affect the others.
        return [
            Document(text=text, metadata=dict(metadata))
            for text in self.run_raw(document.text)
        ]

    def run_batch_document(
        self, document_batch: List[Document]
    ) -> List[List[Document]]:
        """
        Runs a batch of documents through the `run_document` function and returns the
        output for each document.

        Parameters:
            document_batch (List[Document]): A list of Document objects representing the
                batch of documents to process.

        Returns:
            List[List[Document]]: A list of lists where each inner list contains the
                output Document for each input Document in the batch.

        Example:
            document1 = Document(...)
            document2 = Document(...)
            document_batch = [document1, document2]
            batch_output = self.run_batch_document(document_batch)
            # batch_output will be [[output1_document1, ...], [output1_document2, ...]]
        """
        return [self.run_document(each_document) for each_document in document_batch]

    def is_document(self, text) -> bool:
        """
        Check if the given text is an instance of the Document class.

        Args:
            text: The text to check.

        Returns:
            bool: True if the text is an instance of Document, False otherwise.
        """
        return isinstance(text, Document)

    def is_batch(self, text) -> bool:
        """
        Check if the given text is a batch of documents.

        A batch is a list whose items are either all Document instances or all
        non-Document values; an empty list also counts as a batch.

        Parameters:
            text (List): The text to be checked.

        Returns:
            bool: True if the text is a batch of documents, False otherwise.
        """
        if not isinstance(text, list):
            return False

        # At most one distinct answer means the list is homogeneous (or empty).
        return len({self.is_document(each_text) for each_text in text}) <= 1
|
38
tests/test_post_processing.py
Normal file
38
tests/test_post_processing.py
Normal file
@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
|
||||
from kotaemon.documents.base import Document
|
||||
from kotaemon.post_processing.extractor import RegexExtractor
|
||||
|
||||
|
||||
@pytest.fixture
def regex_extractor():
    """Build a RegexExtractor that matches digit runs and maps "1"-"3" to words."""
    digit_words = {"1": "One", "2": "Two", "3": "Three"}
    return RegexExtractor(pattern=r"\d+", output_map=digit_words)
|
||||
|
||||
|
||||
def test_run_document(regex_extractor):
    """Calling the extractor on a Document yields one mapped result per match."""
    source = Document(text="This is a test. 1 2 3")
    results = regex_extractor(source)
    assert [result.text for result in results] == ["One", "Two", "Three"]
|
||||
|
||||
|
||||
def test_is_document(regex_extractor):
    """is_document is true for a Document and false for a plain string."""
    wrapped = Document(text="Test")
    plain = "Test"
    assert regex_extractor.is_document(wrapped)
    assert not regex_extractor.is_document(plain)
|
||||
|
||||
|
||||
def test_is_batch(regex_extractor):
    """is_batch is true for a list of Documents and false for a lone Document."""
    single = Document(text="Test")
    assert regex_extractor.is_batch([Document(text="Test")])
    assert not regex_extractor.is_batch(single)
|
||||
|
||||
|
||||
def test_run_raw(regex_extractor):
    """A plain string input returns the matched digit run as a list of strings."""
    raw_text = "This is a test. 123"
    assert regex_extractor(raw_text) == ["123"]
|
||||
|
||||
|
||||
def test_run_batch_raw(regex_extractor):
    """A list of strings returns one list of matches per input string."""
    raw_batch = ["This is a test. 123", "456"]
    expected = [["123"], ["456"]]
    assert regex_extractor(raw_batch) == expected
|
Loading…
x
Reference in New Issue
Block a user