mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 16:15:35 +00:00
Add improvements to AzureConverter (#1896)
* Add some improvements to AzureConverter * Adapt docstring + use Path instead of str * Fix mypy version to 0.910
This commit is contained in:
parent
e4aec4661d
commit
4edec04c2c
2
.github/workflows/linux_ci.yml
vendored
2
.github/workflows/linux_ci.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
python-version: 3.8
|
||||
- name: Test with mypy
|
||||
run: |
|
||||
pip install mypy types-Markdown types-requests types-PyYAML pydantic
|
||||
pip install mypy==0.910 types-Markdown types-requests types-PyYAML pydantic
|
||||
mypy haystack
|
||||
|
||||
build-cache:
|
||||
|
@ -34,6 +34,7 @@ class AzureConverter(BaseConverter):
|
||||
save_json: bool = False,
|
||||
preceding_context_len: int = 3,
|
||||
following_context_len: int = 3,
|
||||
merge_multiple_column_headers: bool = True,
|
||||
):
|
||||
"""
|
||||
:param endpoint: Your Form Recognizer or Cognitive Services resource's endpoint.
|
||||
@ -51,11 +52,15 @@ class AzureConverter(BaseConverter):
|
||||
:param save_json: Whether to save the output of the Form Recognizer to a JSON file.
|
||||
:param preceding_context_len: Number of lines before a table to extract as preceding context (will be returned as part of meta data).
|
||||
:param following_context_len: Number of lines after a table to extract as subsequent context (will be returned as part of meta data).
|
||||
:param merge_multiple_column_headers: Some tables contain more than one row as a column header (i.e., column description).
|
||||
This parameter lets you choose, whether to merge multiple column header
|
||||
rows to a single row.
|
||||
"""
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(endpoint=endpoint, credential_key=credential_key, model_id=model_id,
|
||||
valid_languages=valid_languages, save_json=save_json,
|
||||
preceding_context_len=preceding_context_len, following_context_len=following_context_len)
|
||||
preceding_context_len=preceding_context_len, following_context_len=following_context_len,
|
||||
merge_multiple_column_headers=merge_multiple_column_headers)
|
||||
|
||||
self.document_analysis_client = DocumentAnalysisClient(endpoint=endpoint,
|
||||
credential=AzureKeyCredential(credential_key))
|
||||
@ -64,6 +69,7 @@ class AzureConverter(BaseConverter):
|
||||
self.save_json = save_json
|
||||
self.preceding_context_len = preceding_context_len
|
||||
self.following_context_len = following_context_len
|
||||
self.merge_multiple_column_headers = merge_multiple_column_headers
|
||||
|
||||
super().__init__(valid_languages=valid_languages)
|
||||
|
||||
@ -96,6 +102,8 @@ class AzureConverter(BaseConverter):
|
||||
:param known_language: Locale hint of the input document.
|
||||
See supported locales here: https://aka.ms/azsdk/formrecognizer/supportedlocales.
|
||||
"""
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
if valid_languages is None:
|
||||
valid_languages = self.valid_languages
|
||||
@ -105,22 +113,57 @@ class AzureConverter(BaseConverter):
|
||||
locale=known_language)
|
||||
result = poller.result()
|
||||
|
||||
if self.save_json:
|
||||
with open(file_path.with_suffix(".json"), "w") as json_file:
|
||||
json.dump(result.to_dict(), json_file, indent=2)
|
||||
|
||||
docs = self._convert_tables_and_text(result, meta, valid_languages, file_path)
|
||||
|
||||
return docs
|
||||
|
||||
def convert_azure_json(self,
|
||||
file_path: Path,
|
||||
meta: Optional[Dict[str, str]] = None,
|
||||
valid_languages: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract text and tables from the JSON output of Azure's Form Recognizer service.
|
||||
|
||||
:param file_path: Path to the JSON-file you want to convert.
|
||||
:param meta: Optional dictionary with metadata that shall be attached to all resulting documents.
|
||||
Can be any custom keys and values.
|
||||
:param valid_languages: Validate languages from a list of languages specified in the ISO 639-1
|
||||
(https://en.wikipedia.org/wiki/ISO_639-1) format.
|
||||
This option can be used to add test for encoding errors. If the extracted text is
|
||||
not one of the valid languages, then it might likely be encoding error resulting
|
||||
in garbled text.
|
||||
"""
|
||||
if valid_languages is None:
|
||||
valid_languages = self.valid_languages
|
||||
|
||||
with open(file_path) as azure_file:
|
||||
azure_result = json.load(azure_file)
|
||||
azure_result = AnalyzeResult.from_dict(azure_result)
|
||||
|
||||
docs = self._convert_tables_and_text(azure_result, meta, valid_languages, file_path)
|
||||
|
||||
return docs
|
||||
|
||||
def _convert_tables_and_text(self, result: AnalyzeResult, meta: Optional[Dict[str, str]],
|
||||
valid_languages: Optional[List[str]], file_path: Path) -> List[Dict[str, Any]]:
|
||||
tables = self._convert_tables(result, meta)
|
||||
text = self._convert_text(result, meta)
|
||||
docs = tables + [text]
|
||||
|
||||
if valid_languages:
|
||||
file_text = text["content"] + " ".join([cell for table in tables for row in table["content"] for cell in row])
|
||||
file_text = text["content"] + " ".join(
|
||||
[cell for table in tables for row in table["content"] for cell in row])
|
||||
if not self.validate_language(file_text):
|
||||
logger.warning(
|
||||
f"The language for {file_path} is not one of {self.valid_languages}. The file may not have "
|
||||
f"been decoded in the correct text format."
|
||||
)
|
||||
|
||||
if self.save_json:
|
||||
with open(str(file_path) + ".json", "w") as json_file:
|
||||
json.dump(result.to_dict(), json_file, indent=2)
|
||||
|
||||
return docs
|
||||
|
||||
def _convert_tables(self, result: AnalyzeResult, meta: Optional[Dict[str, str]]) -> List[Dict[str, Any]]:
|
||||
@ -129,21 +172,37 @@ class AzureConverter(BaseConverter):
|
||||
for table in result.tables:
|
||||
# Initialize table with empty cells
|
||||
table_list = [[""] * table.column_count for _ in range(table.row_count)]
|
||||
additional_column_header_rows = set()
|
||||
caption = ""
|
||||
row_idx_start = 0
|
||||
|
||||
for cell in table.cells:
|
||||
for idx, cell in enumerate(table.cells):
|
||||
# Remove ':selected:'/':unselected:' tags from cell's content
|
||||
cell.content = cell.content.replace(":selected:", "")
|
||||
cell.content = cell.content.replace(":unselected:", "")
|
||||
|
||||
# Check if first row is a merged cell spanning whole table
|
||||
# -> exclude this row and use as a caption
|
||||
if idx == 0 and cell.column_span == table.column_count:
|
||||
caption = cell.content
|
||||
row_idx_start = 1
|
||||
table_list.pop(0)
|
||||
continue
|
||||
|
||||
for c in range(cell.column_span):
|
||||
for r in range(cell.row_span):
|
||||
table_list[cell.row_index + r][cell.column_index + c] = cell.content
|
||||
if self.merge_multiple_column_headers \
|
||||
and cell.kind == "columnHeader" \
|
||||
and cell.row_index > row_idx_start:
|
||||
# More than one row serves as column header
|
||||
table_list[0][cell.column_index + c] += f"\n{cell.content}"
|
||||
additional_column_header_rows.add(cell.row_index - row_idx_start)
|
||||
else:
|
||||
table_list[cell.row_index + r - row_idx_start][cell.column_index + c] = cell.content
|
||||
|
||||
caption = ""
|
||||
# Check if all column names are the same -> exclude these cells and use as caption
|
||||
if all(col_name == table_list[0][0] for col_name in table_list[0]):
|
||||
caption = table_list[0][0]
|
||||
table_list.pop(0)
|
||||
# Remove additional column header rows, as these got attached to the first row
|
||||
for row_idx in sorted(additional_column_header_rows, reverse=True):
|
||||
del table_list[row_idx]
|
||||
|
||||
# Get preceding context of table
|
||||
table_beginning_page = next(page for page in result.pages
|
||||
@ -151,7 +210,8 @@ class AzureConverter(BaseConverter):
|
||||
table_start_offset = table.spans[0].offset
|
||||
preceding_lines = [line.content for line in table_beginning_page.lines
|
||||
if line.spans[0].offset < table_start_offset]
|
||||
preceding_context = f"{caption}\n".strip() + "\n".join(preceding_lines[-self.preceding_context_len:])
|
||||
preceding_context = "\n".join(preceding_lines[-self.preceding_context_len:]) + f"\n{caption}"
|
||||
preceding_context = preceding_context.strip()
|
||||
|
||||
# Get following context
|
||||
table_end_page = table_beginning_page if len(table.bounding_regions) == 1 else \
|
||||
|
Loading…
x
Reference in New Issue
Block a user