Add improvements to AzureConverter (#1896)

* Add some improvements to AzureConverter

* Adapt docstring + use Path instead of str

* Fix mypy version to 0.910
This commit is contained in:
bogdankostic 2021-12-16 12:45:24 +01:00 committed by GitHub
parent e4aec4661d
commit 4edec04c2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 15 deletions

View File

@ -17,7 +17,7 @@ jobs:
python-version: 3.8
- name: Test with mypy
run: |
pip install mypy types-Markdown types-requests types-PyYAML pydantic
pip install mypy==0.910 types-Markdown types-requests types-PyYAML pydantic
mypy haystack
build-cache:

View File

@ -34,6 +34,7 @@ class AzureConverter(BaseConverter):
save_json: bool = False,
preceding_context_len: int = 3,
following_context_len: int = 3,
merge_multiple_column_headers: bool = True,
):
"""
:param endpoint: Your Form Recognizer or Cognitive Services resource's endpoint.
@ -51,11 +52,15 @@ class AzureConverter(BaseConverter):
:param save_json: Whether to save the output of the Form Recognizer to a JSON file.
:param preceding_context_len: Number of lines before a table to extract as preceding context (will be returned as part of meta data).
:param following_context_len: Number of lines after a table to extract as subsequent context (will be returned as part of meta data).
:param merge_multiple_column_headers: Some tables contain more than one row as a column header (i.e., column description).
This parameter lets you choose, whether to merge multiple column header
rows to a single row.
"""
# save init parameters to enable export of component config as YAML
self.set_config(endpoint=endpoint, credential_key=credential_key, model_id=model_id,
valid_languages=valid_languages, save_json=save_json,
preceding_context_len=preceding_context_len, following_context_len=following_context_len)
preceding_context_len=preceding_context_len, following_context_len=following_context_len,
merge_multiple_column_headers=merge_multiple_column_headers)
self.document_analysis_client = DocumentAnalysisClient(endpoint=endpoint,
credential=AzureKeyCredential(credential_key))
@ -64,6 +69,7 @@ class AzureConverter(BaseConverter):
self.save_json = save_json
self.preceding_context_len = preceding_context_len
self.following_context_len = following_context_len
self.merge_multiple_column_headers = merge_multiple_column_headers
super().__init__(valid_languages=valid_languages)
@ -96,6 +102,8 @@ class AzureConverter(BaseConverter):
:param known_language: Locale hint of the input document.
See supported locales here: https://aka.ms/azsdk/formrecognizer/supportedlocales.
"""
if isinstance(file_path, str):
file_path = Path(file_path)
if valid_languages is None:
valid_languages = self.valid_languages
@ -105,22 +113,57 @@ class AzureConverter(BaseConverter):
locale=known_language)
result = poller.result()
if self.save_json:
with open(file_path.with_suffix(".json"), "w") as json_file:
json.dump(result.to_dict(), json_file, indent=2)
docs = self._convert_tables_and_text(result, meta, valid_languages, file_path)
return docs
def convert_azure_json(self,
file_path: Path,
meta: Optional[Dict[str, str]] = None,
valid_languages: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Extract text and tables from the JSON output of Azure's Form Recognizer service.
:param file_path: Path to the JSON-file you want to convert.
:param meta: Optional dictionary with metadata that shall be attached to all resulting documents.
Can be any custom keys and values.
:param valid_languages: Validate languages from a list of languages specified in the ISO 639-1
(https://en.wikipedia.org/wiki/ISO_639-1) format.
This option can be used to add test for encoding errors. If the extracted text is
not one of the valid languages, then it might likely be encoding error resulting
in garbled text.
"""
if valid_languages is None:
valid_languages = self.valid_languages
with open(file_path) as azure_file:
azure_result = json.load(azure_file)
azure_result = AnalyzeResult.from_dict(azure_result)
docs = self._convert_tables_and_text(azure_result, meta, valid_languages, file_path)
return docs
def _convert_tables_and_text(self, result: AnalyzeResult, meta: Optional[Dict[str, str]],
valid_languages: Optional[List[str]], file_path: Path) -> List[Dict[str, Any]]:
tables = self._convert_tables(result, meta)
text = self._convert_text(result, meta)
docs = tables + [text]
if valid_languages:
file_text = text["content"] + " ".join([cell for table in tables for row in table["content"] for cell in row])
file_text = text["content"] + " ".join(
[cell for table in tables for row in table["content"] for cell in row])
if not self.validate_language(file_text):
logger.warning(
f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
f"been decoded in the correct text format."
)
if self.save_json:
with open(str(file_path) + ".json", "w") as json_file:
json.dump(result.to_dict(), json_file, indent=2)
return docs
def _convert_tables(self, result: AnalyzeResult, meta: Optional[Dict[str, str]]) -> List[Dict[str, Any]]:
@ -129,21 +172,37 @@ class AzureConverter(BaseConverter):
for table in result.tables:
# Initialize table with empty cells
table_list = [[""] * table.column_count for _ in range(table.row_count)]
additional_column_header_rows = set()
caption = ""
row_idx_start = 0
for cell in table.cells:
for idx, cell in enumerate(table.cells):
# Remove ':selected:'/':unselected:' tags from cell's content
cell.content = cell.content.replace(":selected:", "")
cell.content = cell.content.replace(":unselected:", "")
# Check if first row is a merged cell spanning whole table
# -> exclude this row and use as a caption
if idx == 0 and cell.column_span == table.column_count:
caption = cell.content
row_idx_start = 1
table_list.pop(0)
continue
for c in range(cell.column_span):
for r in range(cell.row_span):
table_list[cell.row_index + r][cell.column_index + c] = cell.content
if self.merge_multiple_column_headers \
and cell.kind == "columnHeader" \
and cell.row_index > row_idx_start:
# More than one row serves as column header
table_list[0][cell.column_index + c] += f"\n{cell.content}"
additional_column_header_rows.add(cell.row_index - row_idx_start)
else:
table_list[cell.row_index + r - row_idx_start][cell.column_index + c] = cell.content
caption = ""
# Check if all column names are the same -> exclude these cells and use as caption
if all(col_name == table_list[0][0] for col_name in table_list[0]):
caption = table_list[0][0]
table_list.pop(0)
# Remove additional column header rows, as these got attached to the first row
for row_idx in sorted(additional_column_header_rows, reverse=True):
del table_list[row_idx]
# Get preceding context of table
table_beginning_page = next(page for page in result.pages
@ -151,7 +210,8 @@ class AzureConverter(BaseConverter):
table_start_offset = table.spans[0].offset
preceding_lines = [line.content for line in table_beginning_page.lines
if line.spans[0].offset < table_start_offset]
preceding_context = f"{caption}\n".strip() + "\n".join(preceding_lines[-self.preceding_context_len:])
preceding_context = "\n".join(preceding_lines[-self.preceding_context_len:]) + f"\n{caption}"
preceding_context = preceding_context.strip()
# Get following context
table_end_page = table_beginning_page if len(table.bounding_regions) == 1 else \