Add bs4 and overlap (#2271)
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
parent 0d99d45b0f
commit 6b1376b04d
autogen/retrieve_utils.py
@@ -1,10 +1,13 @@
 import glob
 import os
+import re
 from typing import Callable, List, Union
 from urllib.parse import urlparse

 import chromadb
+import markdownify
 import requests
+from bs4 import BeautifulSoup

 if chromadb.__version__ < "0.4.15":
     from chromadb.api import API
@@ -61,6 +64,7 @@ if HAS_UNSTRUCTURED:
     TEXT_FORMATS += UNSTRUCTURED_FORMATS
     TEXT_FORMATS = list(set(TEXT_FORMATS))
 VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})
+RAG_MINIMUM_MESSAGE_LENGTH = int(os.environ.get("RAG_MINIMUM_MESSAGE_LENGTH", 5))


 def split_text_to_chunks(
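Note: the new RAG_MINIMUM_MESSAGE_LENGTH threshold is read from the environment once, at module level, so an override must happen before the module is first imported. A minimal sketch (the value 20 is an arbitrary example):

    import os

    # Must be set before autogen.retrieve_utils is first imported,
    # since RAG_MINIMUM_MESSAGE_LENGTH is read once at import time.
    os.environ["RAG_MINIMUM_MESSAGE_LENGTH"] = "20"  # arbitrary example value

    from autogen.retrieve_utils import split_text_to_chunks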
@@ -68,22 +72,27 @@ def split_text_to_chunks(
     max_tokens: int = 4000,
     chunk_mode: str = "multi_lines",
     must_break_at_empty_line: bool = True,
-    overlap: int = 10,
+    overlap: int = 0,  # number of overlapping lines
 ):
     """Split a long text into chunks of max_tokens."""
     if chunk_mode not in VALID_CHUNK_MODES:
         raise AssertionError
     if chunk_mode == "one_line":
         must_break_at_empty_line = False
+        overlap = 0
     chunks = []
     lines = text.split("\n")
+    num_lines = len(lines)
+    if num_lines < 3 and must_break_at_empty_line:
+        logger.warning("The input text has less than 3 lines. Set `must_break_at_empty_line` to `False`")
+        must_break_at_empty_line = False
     lines_tokens = [count_token(line) for line in lines]
     sum_tokens = sum(lines_tokens)
     while sum_tokens > max_tokens:
         if chunk_mode == "one_line":
             estimated_line_cut = 2
         else:
-            estimated_line_cut = int(max_tokens / sum_tokens * len(lines)) + 1
+            estimated_line_cut = max(int(max_tokens / sum_tokens * len(lines)), 2)
         cnt = 0
         prev = ""
         for cnt in reversed(range(estimated_line_cut)):
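A quick check of the revised estimate: assuming one token per line (as in the overlap test added below), 25 single-letter lines with max_tokens=10 give int(10 / 25 * 25) + 1 = 11 candidate lines under the old formula, but max(int(10 / 25 * 25), 2) = 10 under the new one, so the first probe no longer starts above the token budget.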
@@ -97,19 +106,25 @@ def split_text_to_chunks(
         if cnt == 0:
             logger.warning(
                 f"max_tokens is too small to fit a single line of text. Breaking this line:\n\t{lines[0][:100]} ..."
             )
             if not must_break_at_empty_line:
-                split_len = int(max_tokens / lines_tokens[0] * 0.9 * len(lines[0]))
+                split_len = max(
+                    int(max_tokens / (lines_tokens[0] * 0.9 * len(lines[0]) + 0.1)), RAG_MINIMUM_MESSAGE_LENGTH
+                )
                 prev = lines[0][:split_len]
                 lines[0] = lines[0][split_len:]
                 lines_tokens[0] = count_token(lines[0])
             else:
                 logger.warning("Failed to split docs with must_break_at_empty_line being True, set to False.")
                 must_break_at_empty_line = False
-        chunks.append(prev) if len(prev) > 10 else None  # don't add chunks less than 10 characters
-        lines = lines[cnt:]
-        lines_tokens = lines_tokens[cnt:]
+        (
+            chunks.append(prev) if len(prev) >= RAG_MINIMUM_MESSAGE_LENGTH else None
+        )  # don't add chunks less than RAG_MINIMUM_MESSAGE_LENGTH characters
+        lines = lines[cnt - overlap if cnt > overlap else cnt :]
+        lines_tokens = lines_tokens[cnt - overlap if cnt > overlap else cnt :]
         sum_tokens = sum(lines_tokens)
-    text_to_chunk = "\n".join(lines)
-    chunks.append(text_to_chunk) if len(text_to_chunk) > 10 else None  # don't add chunks less than 10 characters
+    text_to_chunk = "\n".join(lines).strip()
+    (
+        chunks.append(text_to_chunk) if len(text_to_chunk) >= RAG_MINIMUM_MESSAGE_LENGTH else None
+    )  # don't add chunks less than RAG_MINIMUM_MESSAGE_LENGTH characters
     return chunks
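Taken together, the overlap path steps the window back by `overlap` lines each time a chunk is emitted, so consecutive chunks share trailing context. A minimal sketch mirroring the test added below:

    from autogen.retrieve_utils import split_text_to_chunks

    text = "\n".join(chr(i) for i in range(ord("A"), ord("Z")))  # 25 single-letter lines
    chunks = split_text_to_chunks(text, max_tokens=10, overlap=3)
    # Consecutive chunks share three lines, e.g. "...G\nH\nI" then "G\nH\nI\nJ...",
    # per the assertions in the new test.
    chunks = split_text_to_chunks(text, max_tokens=10, overlap=0)
    # overlap=0 (the new default) keeps the old disjoint chunking.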
@@ -185,7 +200,9 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
             if os.path.isfile(item):
                 files.append(item)
             elif is_url(item):
-                files.append(get_file_from_url(item))
+                filepath = get_file_from_url(item)
+                if filepath:
+                    files.append(filepath)
             elif os.path.exists(item):
                 try:
                     files.extend(get_files_from_dir(item, types, recursive))
@@ -201,7 +218,11 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):

     # If the path is a url, download it and return the downloaded file
     if is_url(dir_path):
-        return [get_file_from_url(dir_path)]
+        filepath = get_file_from_url(dir_path)
+        if filepath:
+            return [filepath]
+        else:
+            return []

     if os.path.exists(dir_path):
         for type in types:
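With both guards in place, a URL that fails to download is dropped rather than inserted as None into the result list. A sketch of the new behavior, with hypothetical URLs:

    from autogen.retrieve_utils import get_files_from_dir

    # Hypothetical URLs, for illustration only: a dead link is now skipped
    # instead of leaving a None entry behind.
    files = get_files_from_dir(["https://example.com/missing.pdf", "https://example.com/page.html"])
    assert None not in files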
@@ -215,17 +236,72 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
     return files


+def parse_html_to_markdown(html: str, url: str = None) -> str:
+    """Parse HTML to markdown."""
+    soup = BeautifulSoup(html, "html.parser")
+    title = soup.title.string
+    # Remove javascript and style blocks
+    for script in soup(["script", "style"]):
+        script.extract()
+
+    # Convert to markdown -- Wikipedia gets special attention to get a clean version of the page
+    if isinstance(url, str) and url.startswith("https://en.wikipedia.org/"):
+        body_elm = soup.find("div", {"id": "mw-content-text"})
+        title_elm = soup.find("span", {"class": "mw-page-title-main"})
+
+        if body_elm:
+            # What's the title
+            main_title = soup.title.string
+            if title_elm and len(title_elm) > 0:
+                main_title = title_elm.string
+            webpage_text = "# " + main_title + "\n\n" + markdownify.MarkdownConverter().convert_soup(body_elm)
+        else:
+            webpage_text = markdownify.MarkdownConverter().convert_soup(soup)
+    else:
+        webpage_text = markdownify.MarkdownConverter().convert_soup(soup)
+
+    # Convert newlines
+    webpage_text = re.sub(r"\r\n", "\n", webpage_text)
+    webpage_text = re.sub(r"\n{2,}", "\n\n", webpage_text).strip()
+    webpage_text = "# " + title + "\n\n" + webpage_text
+    return webpage_text
+
+
 def get_file_from_url(url: str, save_path: str = None):
     """Download a file from a URL."""
     if save_path is None:
-        os.makedirs("/tmp/chromadb", exist_ok=True)
-        save_path = os.path.join("/tmp/chromadb", os.path.basename(url))
+        save_path = "tmp/chromadb"
+        os.makedirs(save_path, exist_ok=True)
+    if os.path.isdir(save_path):
+        filename = os.path.basename(url)
+        if filename == "":  # "www.example.com/"
+            filename = url.split("/")[-2]
+        save_path = os.path.join(save_path, filename)
+    else:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
-    with requests.get(url, stream=True) as r:
-        r.raise_for_status()
+
+    custom_headers = {
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
+    }
+    try:
+        response = requests.get(url, stream=True, headers=custom_headers, timeout=30)
+        response.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        logger.warning(f"Failed to download {url}, {e}")
+        return None
+
+    content_type = response.headers.get("content-type", "")
+    if "text/html" in content_type:
+        # Get the content of the response
+        html = ""
+        for chunk in response.iter_content(chunk_size=8192, decode_unicode=True):
+            html += chunk
+        text = parse_html_to_markdown(html, url)
+        with open(save_path, "w", encoding="utf-8") as f:
+            f.write(text)
+    else:
         with open(save_path, "wb") as f:
-            for chunk in r.iter_content(chunk_size=8192):
+            for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
     return save_path
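The new helper can be exercised directly; a minimal sketch (the HTML snippet is illustrative):

    from autogen.retrieve_utils import parse_html_to_markdown

    html = "<html><head><title>Simple HTML Example</title></head><body><h1>Hello, World!</h1></body></html>"
    # The page <title> is prepended as a "# " heading; script/style blocks are
    # stripped and the rest is converted with markdownify.
    print(parse_html_to_markdown(html))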
setup.py
@@ -59,7 +59,7 @@ setuptools.setup(
         ],
         "blendsearch": ["flaml[blendsearch]"],
         "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
-        "retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython"],
+        "retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython", "beautifulsoup4", "markdownify"],
         "autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
         "teachable": ["chromadb"],
         "lmm": ["replicate", "pillow"],
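Because the new dependencies ride along with the existing retrievechat extra, users pick them up with the usual install (assuming pyautogen, the package name this setup.py publishes):

    pip install "pyautogen[retrievechat]"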
test/test_retrieve_utils.py
@@ -13,6 +13,7 @@ try:
         extract_text_from_pdf,
         get_files_from_dir,
         is_url,
+        parse_html_to_markdown,
         query_vector_db,
         split_files_to_chunks,
         split_text_to_chunks,
@@ -49,6 +50,18 @@ class TestRetrieveUtils:
         with pytest.raises(AssertionError):
             split_text_to_chunks("A" * 10000, chunk_mode="bogus_chunk_mode")

+    def test_split_text_to_chunks_overlapping(self):
+        long_text = "\n".join([chr(i) for i in range(ord("A"), ord("Z"))])
+        chunks = split_text_to_chunks(long_text, max_tokens=10, overlap=3)
+        assert chunks == [
+            "A\nB\nC\nD\nE\nF\nG\nH\nI",
+            "G\nH\nI\nJ\nK\nL\nM\nN\nO",
+            "M\nN\nO\nP\nQ\nR\nS\nT\nU",
+            "S\nT\nU\nV\nW\nX\nY",
+        ]
+        chunks = split_text_to_chunks(long_text, max_tokens=10, overlap=0)
+        assert chunks == ["A\nB\nC\nD\nE\nF\nG\nH\nI", "J\nK\nL\nM\nN\nO\nP\nQ\nR", "S\nT\nU\nV\nW\nX\nY"]
+
     def test_extract_text_from_pdf(self):
         pdf_file_path = os.path.join(test_dir, "example.pdf")
         assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split())
@@ -236,6 +249,27 @@ class TestRetrieveUtils:
             for chunk in chunks
         )

+    def test_parse_html_to_markdown(self):
+        html = """
+        <!DOCTYPE html>
+        <html lang="en">
+        <head>
+            <meta charset="UTF-8">
+            <meta name="viewport" content="width=device-width, initial-scale=1.0">
+            <title>Simple HTML Example</title>
+        </head>
+        <body>
+            <h1>Hello, World!</h1>
+            <p>This is a very simple HTML example.</p>
+        </body>
+        </html>
+        """
+        markdown = parse_html_to_markdown(html)
+        assert (
+            markdown
+            == "# Simple HTML Example\n\nSimple HTML Example\n\nHello, World!\n=============\n\nThis is a very simple HTML example."
+        )
+

 if __name__ == "__main__":
     pytest.main()