# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Search module."""

import html
import json
import re
from dataclasses import dataclass

import pandas as pd
import streamlit as st
from rag.typing import SearchResult, SearchType
from streamlit.delta_generator import DeltaGenerator

# Parenthesised citation id list, e.g. "(1, 2, +more)".
_CITATION_IDS_PATTERN = r"\(\d+(?:,\s*\d+)*(?:,\s*\+more)?\)"

# (citation label found in the response text, anchor group the link points at).
# "Documents" citations deliberately resolve to the "Sources" table.
_CITATION_ANCHORS = (
    ("Entities", "Entities"),
    ("Sources", "Sources"),
    ("Documents", "Sources"),
    ("Relationships", "Relationships"),
    ("Reports", "Reports"),
)


def init_search_ui(
    container: DeltaGenerator, search_type: SearchType, title: str, caption: str
):
    """Initialize search UI component.

    Renders the title and caption inside ``container`` and stores, in
    ``st.session_state``, two empty placeholders (response and context) plus
    the container itself, keyed by the lower-cased search type value.
    """
    with container:
        st.markdown(title)
        st.caption(caption)

        ui_tag = search_type.value.lower()
        st.session_state[f"{ui_tag}_response_placeholder"] = st.empty()
        st.session_state[f"{ui_tag}_context_placeholder"] = st.empty()
        st.session_state[f"{ui_tag}_container"] = container


@dataclass
class SearchStats:
    """SearchStats class definition."""

    # Shown truncated to whole seconds in the stats line.
    completion_time: float
    # Number of LLM calls shown in the stats line.
    llm_calls: int
    # Token count shown in the stats line (thousands-separated).
    prompt_tokens: int


def display_search_result(
    container: DeltaGenerator, result: SearchResult, stats: SearchStats | None = None
):
    """Display search results data into the UI.

    Citations inside the response text are converted to in-page links before
    rendering. When ``stats`` is given (and has a completion time), a summary
    line is shown above the response. The rendered response element is stored
    in ``st.session_state`` under ``"<search type>_response_placeholder"``.
    """
    search_type = result.search_type.value.lower()
    response_placeholder_attr = search_type + "_response_placeholder"

    with container:
        # display response
        response = format_response_hyperlinks(result.response, search_type)

        if stats is not None and stats.completion_time is not None:
            st.markdown(
                f"*{stats.prompt_tokens:,} tokens used, {stats.llm_calls} LLM calls, {int(stats.completion_time)} seconds elapsed.*"
            )
        # unsafe_allow_html is required so the generated citation links render.
        st.session_state[response_placeholder_attr] = st.markdown(
            f"<div>{response}</div>",
            unsafe_allow_html=True,
        )


def display_citations(
    container: DeltaGenerator | None = None, result: SearchResult | None = None
):
    """Display citations into the UI.

    Renders one HTML table per non-empty entry of ``result.context``, in
    alphabetical key order. Does nothing when either argument is ``None``.
    """
    if container is None or result is None:
        return

    search_type = result.search_type.value.lower()
    with container:
        # display context used for generating the response
        st.markdown("---")
        st.markdown("### Citations")
        for key, value in sorted(result.context.items()):
            # len() rather than truthiness: a DataFrame has no truth value.
            if len(value) == 0:
                continue

            if key == "sources":
                st.markdown(
                    f"Relevant chunks of source documents **({len(value)})**:"
                )
            elif key == "reports":
                st.markdown(
                    f"Relevant AI-generated network reports **({len(value)})**:"
                )
            else:
                st.markdown(f"Relevant AI-extracted {key} **({len(value)})**:")

            st.markdown(
                render_html_table(value, search_type, key),
                unsafe_allow_html=True,
            )


def format_response_hyperlinks(str_response: str, search_type: str = ""):
    """Format response to show hyperlinks inside the response UI."""
    results_with_hyperlinks = str_response
    for key, anchor in _CITATION_ANCHORS:
        results_with_hyperlinks = format_response_hyperlinks_by_key(
            results_with_hyperlinks, key, anchor, search_type
        )
    return results_with_hyperlinks


def format_response_hyperlinks_by_key(
    str_response: str, key: str, anchor: str, search_type: str = ""
):
    """Format response to show hyperlinks inside the response UI by key.

    Every citation of the form ``"<key> (1, 2, +more)"`` has each numeric id
    wrapped in a link to ``#<search_type>-<anchor>-<id>``, the row id scheme
    produced by ``render_html_table``. ``+more`` markers are left untouched.

    ``key`` is interpolated into the regular expression as-is.
    """
    id_prefix = f"{search_type.lower().strip()}-{anchor.lower().strip()}"

    def _link_number(number_match: re.Match) -> str:
        number = number_match.group(0)
        return f'<a href="#{id_prefix}-{number}">{number}</a>'

    def _link_citation(citation_match: re.Match) -> str:
        citation = citation_match.group(0)
        # The id list is the last parenthesised group of the match; only that
        # part is rewritten so digits in the label are never touched.
        label, paren, id_list = citation.rpartition("(")
        # One regex pass per id list: each id is replaced exactly once, so
        # overlapping ids (1 vs 11) and digits inside already-generated hrefs
        # cannot be corrupted.
        return label + paren + re.sub(r"\d+", _link_number, id_list)

    return re.sub(f"{key} {_CITATION_IDS_PATTERN}", _link_citation, str_response)


def format_suggested_questions(questions: str):
    """Format suggested questions to the UI.

    Strips bracketed citation markers (``[...]``) and returns the numbered
    questions as a list of strings.
    """
    citations_pattern = r"\[.*?\]"
    substring = re.sub(citations_pattern, "", questions).strip()
    return convert_numbered_list_to_array(substring)


def convert_numbered_list_to_array(numbered_list_str):
    """Convert numbered list result into an array of elements.

    Lines that do not start with ``"<digits>."`` are ignored.
    """
    return [
        match.group(1).strip()
        for line in numbered_list_str.strip().split("\n")
        if (match := re.match(r"^\d+\.\s*(.*)", line))
    ]


def get_ids_per_key(str_response: str, key: str):
    """Filter ids per key.

    Returns the comma-separated items of every ``"<key> (...)"`` citation found
    in ``str_response``, in order of appearance. Items are returned as they
    appear in the text (surrounding whitespace and ``+more`` included).
    """
    numbers_list = []
    for occurrence in re.findall(f"{key} {_CITATION_IDS_PATTERN}", str_response):
        # The match always ends with ")", so the id list is what sits between
        # the last "(" and that final character.
        id_list = occurrence[occurrence.rfind("(") + 1 : -1]
        # Accumulate across occurrences instead of keeping only the last one.
        numbers_list.extend(id_list.split(","))
    return numbers_list


SHORT_WORDS = 12
LONG_WORDS = 200


def _truncate_words(text: str, limit: int) -> str:
    """Return ``text`` cut to ``limit`` space-separated words, with "..." if cut."""
    words = text.split(" ")
    if len(words) > limit:
        return " ".join(words[:limit]) + "..."
    return text


def _extract_summary(value: str) -> str:
    """Return the "summary" field of a JSON object string, else ``value`` itself."""
    if not value.startswith("{"):
        return value
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        # Plain text that merely starts with "{": show it unchanged.
        return value
    if isinstance(parsed, dict) and "summary" in parsed:
        return str(parsed["summary"])
    return value


def _render_cell(value, td_style: str) -> str:
    """Render one table cell; long strings are shortened, full text goes in the tooltip."""
    if not isinstance(value, str):
        text = html.escape(str(value))
        return f'<td style="{td_style}" title="{text}">{text}</td>'

    text = _extract_summary(value)
    td_value = html.escape(_truncate_words(text, SHORT_WORDS))
    # Longest sequences first, otherwise "\r\n" and "\n\n" can never match.
    title_value = (
        _truncate_words(text, LONG_WORDS)
        .replace("\r\n", " ")
        .replace("\n\n", " ")
        .replace("\n", " ")
    )
    # html.escape covers quotes as well as &, < and >, keeping the attribute valid.
    title_value = html.escape(title_value)
    return f'<td style="{td_style}" title="{title_value}">{td_value}</td>'


# Function to generate HTML table with ids
def render_html_table(df: pd.DataFrame, search_type: str, key: str):
    """Render HTML table into the UI.

    Each row gets the element id ``<search_type>-<key>-<row id>`` (lower-cased,
    stripped) when the data has an ``id`` column, otherwise ``row-<index>``;
    citation links in the response text point at those ids.

    Args:
        df: Tabular data; anything accepted by ``pd.DataFrame``.
        search_type: Search type tag used as the first part of the row ids.
        key: Citation group (e.g. ``"sources"``) used as the second part.

    Returns:
        The table as an HTML string.
    """
    table_container = "max-width: 100%; overflow: hidden; margin: 0 auto;"
    table_style = "width: 100%; border-collapse: collapse; table-layout: fixed;"
    th_style = "word-wrap: break-word; white-space: normal;"
    td_style = "border: 1px solid #efefef; word-wrap: break-word; white-space: normal;"

    frame = pd.DataFrame(df)
    id_prefix = f"{search_type.lower().strip()}-{key.lower().strip()}"

    parts = [
        f'<div style="{table_container}">',
        f'<table style="{table_style}">',
        "<thead><tr>",
    ]
    parts.extend(
        f'<th style="{th_style}">{html.escape(str(col))}</th>'
        for col in frame.columns
    )
    parts.append("</tr></thead>")
    parts.append("<tbody>")

    for index, row in frame.iterrows():
        # str() so numeric ids work as well as string ids.
        html_id = (
            f"{id_prefix}-{str(row['id']).strip()}"
            if "id" in row
            else f"row-{index}"
        )
        parts.append(f'<tr id="{html.escape(html_id)}">')
        parts.extend(_render_cell(value, td_style) for value in row)
        parts.append("</tr>")

    parts.append("</tbody></table></div>")
    return "".join(parts)


def display_graph_citations(
    entities: pd.DataFrame, relationships: pd.DataFrame, citation_type: str
):
    """Display graph citations into the UI.

    Renders the entities table followed by the relationships table, each with
    a heading showing its row count.
    """
    st.markdown("---")
    st.markdown("### Citations")

    st.markdown(f"Relevant AI-extracted entities **({len(entities)})**:")
    st.markdown(
        render_html_table(entities, citation_type, "entities"),
        unsafe_allow_html=True,
    )

    st.markdown(f"Relevant AI-extracted relationships **({len(relationships)})**:")
    st.markdown(
        render_html_table(relationships, citation_type, "relationships"),
        unsafe_allow_html=True,
    )