mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-27 09:04:11 +00:00)
Add improvements to AzureConverter (#1896)
* Add some improvements to AzureConverter
* Adapt docstring + use Path instead of str
* Fix mypy version to 0.910
This commit is contained in:
parent e4aec4661d
commit 4edec04c2c
.github/workflows/linux_ci.yml (vendored): 2 changes
@@ -17,7 +17,7 @@ jobs:
          python-version: 3.8
      - name: Test with mypy
        run: |
-         pip install mypy types-Markdown types-requests types-PyYAML pydantic
+         pip install mypy==0.910 types-Markdown types-requests types-PyYAML pydantic
          mypy haystack

   build-cache:

@@ -34,6 +34,7 @@ class AzureConverter(BaseConverter):
         save_json: bool = False,
         preceding_context_len: int = 3,
         following_context_len: int = 3,
+        merge_multiple_column_headers: bool = True,
     ):
         """
         :param endpoint: Your Form Recognizer or Cognitive Services resource's endpoint.
@@ -51,11 +52,15 @@ class AzureConverter(BaseConverter):
         :param save_json: Whether to save the output of the Form Recognizer to a JSON file.
         :param preceding_context_len: Number of lines before a table to extract as preceding context (will be returned as part of the metadata).
         :param following_context_len: Number of lines after a table to extract as subsequent context (will be returned as part of the metadata).
+        :param merge_multiple_column_headers: Some tables contain more than one row as a column header (i.e., column description).
+                                              This parameter lets you choose whether to merge multiple column header
+                                              rows into a single row.
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(endpoint=endpoint, credential_key=credential_key, model_id=model_id,
                         valid_languages=valid_languages, save_json=save_json,
-                        preceding_context_len=preceding_context_len, following_context_len=following_context_len)
+                        preceding_context_len=preceding_context_len, following_context_len=following_context_len,
+                        merge_multiple_column_headers=merge_multiple_column_headers)

         self.document_analysis_client = DocumentAnalysisClient(endpoint=endpoint,
                                                                credential=AzureKeyCredential(credential_key))
@@ -64,6 +69,7 @@ class AzureConverter(BaseConverter):
         self.save_json = save_json
         self.preceding_context_len = preceding_context_len
         self.following_context_len = following_context_len
+        self.merge_multiple_column_headers = merge_multiple_column_headers

         super().__init__(valid_languages=valid_languages)

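
A usage sketch (not part of the diff) showing the new constructor flag; the import path, endpoint, and key below are placeholders and may differ by Haystack version:

    from haystack.nodes import AzureConverter  # assumed import path

    converter = AzureConverter(
        endpoint="https://<your-resource>.cognitiveservices.azure.com/",  # placeholder
        credential_key="<your-api-key>",                                  # placeholder
        save_json=True,                      # persist the raw Form Recognizer result
        merge_multiple_column_headers=True,  # new: collapse multi-row column headers into one
    )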
@@ -96,6 +102,8 @@ class AzureConverter(BaseConverter):
         :param known_language: Locale hint of the input document.
                                See supported locales here: https://aka.ms/azsdk/formrecognizer/supportedlocales.
         """
+        if isinstance(file_path, str):
+            file_path = Path(file_path)

         if valid_languages is None:
             valid_languages = self.valid_languages
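
With this guard, convert() accepts the file path as either str or pathlib.Path; a brief sketch (placeholder file name, converter as constructed above):

    from pathlib import Path

    # Both calls are now equivalent; a str is coerced to Path internally.
    docs_from_str = converter.convert(file_path="invoice.pdf")
    docs_from_path = converter.convert(file_path=Path("invoice.pdf"))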
@@ -105,22 +113,57 @@ class AzureConverter(BaseConverter):
                                                                      locale=known_language)
         result = poller.result()

+        if self.save_json:
+            with open(file_path.with_suffix(".json"), "w") as json_file:
+                json.dump(result.to_dict(), json_file, indent=2)
+
+        docs = self._convert_tables_and_text(result, meta, valid_languages, file_path)
+
+        return docs
+
+    def convert_azure_json(self,
+                           file_path: Path,
+                           meta: Optional[Dict[str, str]] = None,
+                           valid_languages: Optional[List[str]] = None,
+                           ) -> List[Dict[str, Any]]:
+        """
+        Extract text and tables from the JSON output of Azure's Form Recognizer service.
+
+        :param file_path: Path to the JSON file you want to convert.
+        :param meta: Optional dictionary with metadata that shall be attached to all resulting documents.
+                     Can be any custom keys and values.
+        :param valid_languages: Validate languages from a list of languages specified in the ISO 639-1
+                                (https://en.wikipedia.org/wiki/ISO_639-1) format.
+                                This option can be used to add a check for encoding errors. If the extracted text is
+                                not one of the valid languages, it is likely an encoding error resulting
+                                in garbled text.
+        """
+        if valid_languages is None:
+            valid_languages = self.valid_languages
+
+        with open(file_path) as azure_file:
+            azure_result = json.load(azure_file)
+            azure_result = AnalyzeResult.from_dict(azure_result)
+
+        docs = self._convert_tables_and_text(azure_result, meta, valid_languages, file_path)
+
+        return docs
+
+    def _convert_tables_and_text(self, result: AnalyzeResult, meta: Optional[Dict[str, str]],
+                                 valid_languages: Optional[List[str]], file_path: Path) -> List[Dict[str, Any]]:
         tables = self._convert_tables(result, meta)
         text = self._convert_text(result, meta)
         docs = tables + [text]

         if valid_languages:
-            file_text = text["content"] + " ".join([cell for table in tables for row in table["content"] for cell in row])
+            file_text = text["content"] + " ".join(
+                [cell for table in tables for row in table["content"] for cell in row])
             if not self.validate_language(file_text):
                 logger.warning(
                     f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
                     f"been decoded in the correct text format."
                 )

-        if self.save_json:
-            with open(str(file_path) + ".json", "w") as json_file:
-                json.dump(result.to_dict(), json_file, indent=2)
-
         return docs

     def _convert_tables(self, result: AnalyzeResult, meta: Optional[Dict[str, str]]) -> List[Dict[str, Any]]:
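
Two consequences of moving the save_json block into convert(): the JSON is now written via file_path.with_suffix(".json") (so "invoice.pdf" yields "invoice.json" rather than the old "invoice.pdf.json"), and a stored result can be re-converted offline through the new convert_azure_json() without another call to the Azure service. A hedged round-trip sketch (placeholder file names, converter as constructed above with save_json=True):

    from pathlib import Path

    # First pass: queries Form Recognizer and writes invoice.json next to invoice.pdf.
    docs = converter.convert(file_path="invoice.pdf")

    # Later passes: rebuild the same documents from the saved JSON only, no API call.
    docs_offline = converter.convert_azure_json(file_path=Path("invoice.json"))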
@@ -129,21 +172,37 @@ class AzureConverter(BaseConverter):
         for table in result.tables:
             # Initialize table with empty cells
             table_list = [[""] * table.column_count for _ in range(table.row_count)]
+            additional_column_header_rows = set()
+            caption = ""
+            row_idx_start = 0

-            for cell in table.cells:
+            for idx, cell in enumerate(table.cells):
                 # Remove ':selected:'/':unselected:' tags from cell's content
                 cell.content = cell.content.replace(":selected:", "")
                 cell.content = cell.content.replace(":unselected:", "")

+                # Check if the first row is a merged cell spanning the whole table
+                # -> exclude this row and use it as a caption
+                if idx == 0 and cell.column_span == table.column_count:
+                    caption = cell.content
+                    row_idx_start = 1
+                    table_list.pop(0)
+                    continue
+
                 for c in range(cell.column_span):
                     for r in range(cell.row_span):
-                        table_list[cell.row_index + r][cell.column_index + c] = cell.content
+                        if self.merge_multiple_column_headers \
+                                and cell.kind == "columnHeader" \
+                                and cell.row_index > row_idx_start:
+                            # More than one row serves as column header
+                            table_list[0][cell.column_index + c] += f"\n{cell.content}"
+                            additional_column_header_rows.add(cell.row_index - row_idx_start)
+                        else:
+                            table_list[cell.row_index + r - row_idx_start][cell.column_index + c] = cell.content

-            caption = ""
-            # Check if all column names are the same -> exclude these cells and use as caption
-            if all(col_name == table_list[0][0] for col_name in table_list[0]):
-                caption = table_list[0][0]
-                table_list.pop(0)
+            # Remove additional column header rows, as these got attached to the first row
+            for row_idx in sorted(additional_column_header_rows, reverse=True):
+                del table_list[row_idx]

             # Get preceding context of table
             table_beginning_page = next(page for page in result.pages
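
To make the new bookkeeping concrete, here is a minimal standalone sketch that mirrors the loop above using plain dicts in place of azure.ai.formrecognizer cell objects; all table data is made up:

    # Two header rows followed by one body row; rows 0 and 1 are both "columnHeader".
    row_count, column_count = 3, 2
    table_list = [[""] * column_count for _ in range(row_count)]
    additional_column_header_rows = set()
    row_idx_start = 0  # no caption row in this made-up example

    cells = [
        {"kind": "columnHeader", "row_index": 0, "column_index": 0, "row_span": 1, "column_span": 1, "content": "Region"},
        {"kind": "columnHeader", "row_index": 0, "column_index": 1, "row_span": 1, "column_span": 1, "content": "Revenue"},
        {"kind": "columnHeader", "row_index": 1, "column_index": 0, "row_span": 1, "column_span": 1, "content": ""},
        {"kind": "columnHeader", "row_index": 1, "column_index": 1, "row_span": 1, "column_span": 1, "content": "2021"},
        {"kind": "content", "row_index": 2, "column_index": 0, "row_span": 1, "column_span": 1, "content": "EMEA"},
        {"kind": "content", "row_index": 2, "column_index": 1, "row_span": 1, "column_span": 1, "content": "1.2M"},
    ]

    for cell in cells:
        for c in range(cell["column_span"]):
            for r in range(cell["row_span"]):
                if cell["kind"] == "columnHeader" and cell["row_index"] > row_idx_start:
                    # Later header rows get appended to the first header row.
                    table_list[0][cell["column_index"] + c] += f"\n{cell['content']}"
                    additional_column_header_rows.add(cell["row_index"] - row_idx_start)
                else:
                    table_list[cell["row_index"] + r - row_idx_start][cell["column_index"] + c] = cell["content"]

    # The merged-away header rows are dropped afterwards.
    for row_idx in sorted(additional_column_header_rows, reverse=True):
        del table_list[row_idx]

    print(table_list)  # [['Region\n', 'Revenue\n2021'], ['EMEA', '1.2M']]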
@@ -151,7 +210,8 @@ class AzureConverter(BaseConverter):
             table_start_offset = table.spans[0].offset
             preceding_lines = [line.content for line in table_beginning_page.lines
                                if line.spans[0].offset < table_start_offset]
-            preceding_context = f"{caption}\n".strip() + "\n".join(preceding_lines[-self.preceding_context_len:])
+            preceding_context = "\n".join(preceding_lines[-self.preceding_context_len:]) + f"\n{caption}"
+            preceding_context = preceding_context.strip()

             # Get following context
             table_end_page = table_beginning_page if len(table.bounding_regions) == 1 else \
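
The old line applied .strip() to the caption f-string itself, which removed the separating newline and glued the caption directly onto the first context line; the new order appends the caption closest to the table and strips once at the end. A tiny self-contained check (made-up lines):

    preceding_context_len = 3  # mirrors the default preceding_context_len
    caption = "Quarterly revenue"
    preceding_lines = ["Some introductory text.", "All figures in EUR.", "See the table below:"]

    preceding_context = "\n".join(preceding_lines[-preceding_context_len:]) + f"\n{caption}"
    preceding_context = preceding_context.strip()
    print(preceding_context)
    # Some introductory text.
    # All figures in EUR.
    # See the table below:
    # Quarterly revenue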