refact: mark unit tests under the test/nodes/** path (#4235)

* document merger

* mark unit tests

* revert
Massimiliano Pippi 2023-02-27 15:00:19 +01:00 committed by GitHub
parent efe46b1214
commit 4b8d195288
7 changed files with 215 additions and 78 deletions
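For context on what the new markers buy: once a test is decorated with @pytest.mark.unit, the unit subset can be selected on its own. A minimal sketch of one such invocation through pytest's public API, assuming the marker is registered in the project's pytest configuration (the marker name and the test/nodes path come from this commit; the script itself is illustrative):

import pytest

# Select only tests decorated with @pytest.mark.unit, limited to the
# test/nodes path this commit touches; equivalent to running
# `pytest -m unit test/nodes` from the command line.
if __name__ == "__main__":
    raise SystemExit(pytest.main(["-m", "unit", "test/nodes"]))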

View File

@ -87,12 +87,14 @@ def test_crawler(tmp_path):
#
@pytest.mark.unit
def test_crawler_url_none_exception(tmp_path):
crawler = Crawler()
with pytest.raises(ValueError):
crawler.crawl()
@pytest.mark.unit
def test_crawler_depth_0_single_url(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path, crawler_depth=0, file_path_meta_field_name="file_path")
documents = crawler.crawl(urls=[test_url + "/index.html"])
@ -100,6 +102,7 @@ def test_crawler_depth_0_single_url(test_url, tmp_path):
assert content_match(crawler, test_url + "/index.html", documents[0].meta["file_path"])
@pytest.mark.unit
def test_crawler_depth_0_many_urls(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path, file_path_meta_field_name="file_path")
_urls = [test_url + "/index.html", test_url + "/page1.html"]
@ -110,6 +113,7 @@ def test_crawler_depth_0_many_urls(test_url, tmp_path):
assert content_in_results(crawler, test_url + "/page1.html", paths)
@pytest.mark.unit
def test_crawler_depth_1_single_url(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path, file_path_meta_field_name="file_path")
documents = crawler.crawl(urls=[test_url + "/index.html"], crawler_depth=1)
@ -120,6 +124,7 @@ def test_crawler_depth_1_single_url(test_url, tmp_path):
assert content_in_results(crawler, test_url + "/page2.html", paths)
@pytest.mark.unit
def test_crawler_output_file_structure(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path, file_path_meta_field_name="file_path")
documents = crawler.crawl(urls=[test_url + "/index.html"], crawler_depth=0)
@ -134,6 +139,7 @@ def test_crawler_output_file_structure(test_url, tmp_path):
assert len(data["content"].split()) > 2
@pytest.mark.unit
def test_crawler_filter_urls(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path, file_path_meta_field_name="file_path")
@ -148,6 +154,7 @@ def test_crawler_filter_urls(test_url, tmp_path):
assert not crawler.crawl(urls=[test_url + "/index.html"], filter_urls=["google.com"], crawler_depth=1)
@pytest.mark.unit
def test_crawler_extract_hidden_text(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path)
documents, _ = crawler.run(urls=[test_url + "/page_w_hidden_text.html"], extract_hidden_text=True, crawler_depth=0)
@ -159,6 +166,7 @@ def test_crawler_extract_hidden_text(test_url, tmp_path):
assert "hidden text" not in crawled_content
@pytest.mark.unit
def test_crawler_loading_wait_time(test_url, tmp_path):
loading_wait_time = 3
crawler = Crawler(output_dir=tmp_path, file_path_meta_field_name="file_path")
@ -189,6 +197,7 @@ def test_crawler_loading_wait_time(test_url, tmp_path):
assert content_in_results(crawler, test_url + "/page2.html", paths)
@pytest.mark.unit
def test_crawler_default_naming_function(test_url, tmp_path):
crawler = Crawler(output_dir=tmp_path, file_path_meta_field_name="file_path")
@ -204,6 +213,7 @@ def test_crawler_default_naming_function(test_url, tmp_path):
assert path == Path(expected_crawled_file_path)
@pytest.mark.unit
def test_crawler_naming_function(test_url, tmp_path):
crawler = Crawler(
output_dir=tmp_path,
@ -221,12 +231,14 @@ def test_crawler_naming_function(test_url, tmp_path):
assert path == expected_crawled_file_path
@pytest.mark.unit
def test_crawler_not_save_file(test_url):
crawler = Crawler()
documents = crawler.crawl(urls=[test_url + "/index.html"], crawler_depth=0)
assert documents[0].meta.get("file_path", None) is None
@pytest.mark.unit
def test_crawler_custom_meta_file_path_name(test_url, tmp_path):
crawler = Crawler()
documents = crawler.crawl(

View File

@ -1,73 +1,82 @@
import pytest
from haystack import Document
from haystack.nodes.other.document_merger import DocumentMerger
doc_dicts = [
{
"meta": {
"name": "name_1",
"year": "2020",
"month": "01",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}, "d": "I will be dropped by the meta merge algorithm"},
},
"content": "text_1",
},
{
"meta": {
"name": "name_2",
"year": "2020",
"month": "02",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_2",
},
{
"meta": {
"name": "name_3",
"year": "2020",
"month": "03",
"flat_field": 1,
"nested_field": {1: 2, "a": 7, "c": {"3": 3}},
},
"content": "text_3",
},
{
"meta": {
"name": "name_4",
"year": "2021",
"month": "01",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_4",
},
{
"meta": {
"name": "name_5",
"year": "2021",
"month": "02",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_5",
},
{
"meta": {
"name": "name_6",
"year": "2021",
"month": "03",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_6",
},
]
documents = [Document.from_dict(doc) for doc in doc_dicts]
@pytest.fixture
def doc_dicts():
return [
{
"meta": {
"name": "name_1",
"year": "2020",
"month": "01",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}, "d": "I will be dropped by the meta merge algorithm"},
},
"content": "text_1",
},
{
"meta": {
"name": "name_2",
"year": "2020",
"month": "02",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_2",
},
{
"meta": {
"name": "name_3",
"year": "2020",
"month": "03",
"flat_field": 1,
"nested_field": {1: 2, "a": 7, "c": {"3": 3}},
},
"content": "text_3",
},
{
"meta": {
"name": "name_4",
"year": "2021",
"month": "01",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_4",
},
{
"meta": {
"name": "name_5",
"year": "2021",
"month": "02",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_5",
},
{
"meta": {
"name": "name_6",
"year": "2021",
"month": "03",
"flat_field": 1,
"nested_field": {1: 2, "a": 5, "c": {"3": 3}},
},
"content": "text_6",
},
]
def test_document_merger_merge():
@pytest.fixture
def documents(doc_dicts):
return [Document.from_dict(doc) for doc in doc_dicts]
@pytest.mark.unit
def test_document_merger_merge(documents, doc_dicts):
separator = "|"
dm = DocumentMerger(separator=separator)
merged_list = dm.merge(documents)
@ -77,7 +86,8 @@ def test_document_merger_merge():
assert merged_list[0].meta == {"flat_field": 1, "nested_field": {1: 2, "c": {"3": 3}}}
def test_document_merger_run():
@pytest.mark.unit
def test_document_merger_run(documents, doc_dicts):
separator = "|"
dm = DocumentMerger(separator=separator)
result = dm.run(documents)
@ -87,7 +97,8 @@ def test_document_merger_run():
assert result[0]["documents"][0].meta == {"flat_field": 1, "nested_field": {1: 2, "c": {"3": 3}}}
def test_document_merger_run_batch():
@pytest.mark.unit
def test_document_merger_run_batch(documents, doc_dicts):
separator = "|"
dm = DocumentMerger(separator=separator)
batch_result = dm.run_batch([documents, documents])
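The document merger changes above move the module-level doc_dicts and documents lists into pytest fixtures, so each test now receives fresh data by naming the fixture as a parameter instead of reading shared module state. A minimal sketch of that pattern with illustrative names (not taken from the diff):

import pytest

@pytest.fixture
def sample_dicts():
    # Built fresh for every test that requests it, instead of being
    # module-level state that one test could mutate for the next.
    return [{"content": "text_1"}, {"content": "text_2"}]

@pytest.mark.unit
def test_uses_fixture(sample_dicts):
    # pytest injects the fixture by matching the parameter name.
    assert len(sample_dicts) == 2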

View File

@ -165,12 +165,14 @@ def test_language_validation(Converter, caplog):
assert "sample_pdf_1.pdf is not one of ['de']." in caplog.text
@pytest.mark.unit
def test_docx_converter():
converter = DocxToTextConverter()
document = converter.convert(file_path=SAMPLES_PATH / "docx" / "sample_docx.docx")[0]
assert document.content.startswith("Sample Docx File")
@pytest.mark.unit
def test_markdown_converter():
converter = MarkdownConverter()
document = converter.convert(file_path=SAMPLES_PATH / "markdown" / "sample.md")[0]
@ -178,6 +180,7 @@ def test_markdown_converter():
assert "# git clone https://github.com/deepset-ai/haystack.git" not in document.content
@pytest.mark.unit
def test_markdown_converter_headline_extraction():
expected_headlines = [
("What to build with Haystack", 1),
@ -203,6 +206,7 @@ def test_markdown_converter_headline_extraction():
assert extracted_headline["headline"] == document.content[start_idx : start_idx + hl_len]
@pytest.mark.unit
def test_markdown_converter_frontmatter_to_meta():
converter = MarkdownConverter(add_frontmatter_to_meta=True)
document = converter.convert(file_path=SAMPLES_PATH / "markdown" / "sample.md")[0]
@ -210,6 +214,7 @@ def test_markdown_converter_frontmatter_to_meta():
assert document.meta["date"] == "1.1.2023"
@pytest.mark.unit
def test_markdown_converter_remove_code_snippets():
converter = MarkdownConverter(remove_code_snippets=False)
document = converter.convert(file_path=SAMPLES_PATH / "markdown" / "sample.md")[0]
@ -302,6 +307,7 @@ def test_parsr_converter_headline_extraction():
assert extracted_headline["headline"] == doc.content[start_idx : start_idx + hl_len]
@pytest.mark.unit
def test_id_hash_keys_from_pipeline_params():
doc_path = SAMPLES_PATH / "docs" / "doc_1.txt"
meta_1 = {"key": "a"}
@ -317,13 +323,14 @@ def test_id_hash_keys_from_pipeline_params():
assert len(unique_ids) == 2
@pytest.mark.unit
def write_as_csv(data: List[List[str]], file_path: Path):
with open(file_path, "w") as f:
writer = csv.writer(f)
writer.writerows(data)
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_with_qa_headers(tmp_path):
node = CsvTextConverter()
csv_path = tmp_path / "csv_qa_with_headers.csv"
@ -344,7 +351,7 @@ def test_csv_to_document_with_qa_headers(tmp_path):
assert doc.meta["answer"] == "Haystack is an NLP Framework to use transformers in your Applications."
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_with_wrong_qa_headers(tmp_path):
node = CsvTextConverter()
csv_path = tmp_path / "csv_qa_with_wrong_headers.csv"
@ -358,7 +365,7 @@ def test_csv_to_document_with_wrong_qa_headers(tmp_path):
node.run(file_paths=csv_path)
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_with_one_wrong_qa_headers(tmp_path):
node = CsvTextConverter()
csv_path = tmp_path / "csv_qa_with_wrong_headers.csv"
@ -372,7 +379,7 @@ def test_csv_to_document_with_one_wrong_qa_headers(tmp_path):
node.run(file_paths=csv_path)
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_with_another_wrong_qa_headers(tmp_path):
node = CsvTextConverter()
csv_path = tmp_path / "csv_qa_with_wrong_headers.csv"
@ -386,7 +393,7 @@ def test_csv_to_document_with_another_wrong_qa_headers(tmp_path):
node.run(file_paths=csv_path)
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_with_one_column(tmp_path):
node = CsvTextConverter()
csv_path = tmp_path / "csv_qa_with_wrong_headers.csv"
@ -397,7 +404,7 @@ def test_csv_to_document_with_one_column(tmp_path):
node.run(file_paths=csv_path)
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_with_three_columns(tmp_path):
node = CsvTextConverter()
csv_path = tmp_path / "csv_qa_with_wrong_headers.csv"
@ -411,7 +418,7 @@ def test_csv_to_document_with_three_columns(tmp_path):
node.run(file_paths=csv_path)
@pytest.mark.integration
@pytest.mark.unit
def test_csv_to_document_many_files(tmp_path):
csv_paths = []
for i in range(5):
@ -439,7 +446,7 @@ def test_csv_to_document_many_files(tmp_path):
assert doc.meta["answer"] == f"{i}. Haystack is an NLP Framework to use transformers in your Applications."
@pytest.mark.integration
@pytest.mark.unit
class TestJsonConverter:
JSON_FILE_NAME = "json_normal.json"
JSONL_FILE_NAME = "json_normal.jsonl"

View File

@ -9,6 +9,7 @@ from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT
from ..conftest import SAMPLES_PATH
@pytest.mark.unit
def test_filetype_classifier_single_file(tmp_path):
node = FileTypeClassifier()
test_files = [tmp_path / f"test.{extension}" for extension in DEFAULT_TYPES]
@ -19,6 +20,7 @@ def test_filetype_classifier_single_file(tmp_path):
assert output == {"file_paths": [test_file]}
@pytest.mark.unit
def test_filetype_classifier_many_files(tmp_path):
node = FileTypeClassifier()
@ -30,6 +32,7 @@ def test_filetype_classifier_many_files(tmp_path):
assert output == {"file_paths": test_files}
@pytest.mark.unit
def test_filetype_classifier_many_files_mixed_extensions(tmp_path):
node = FileTypeClassifier()
test_files = [tmp_path / f"test.{extension}" for extension in DEFAULT_TYPES]
@ -38,6 +41,7 @@ def test_filetype_classifier_many_files_mixed_extensions(tmp_path):
node.run(test_files)
@pytest.mark.unit
def test_filetype_classifier_unsupported_extension(tmp_path):
node = FileTypeClassifier()
test_file = tmp_path / f"test.really_weird_extension"
@ -45,6 +49,7 @@ def test_filetype_classifier_unsupported_extension(tmp_path):
node.run(test_file)
@pytest.mark.unit
def test_filetype_classifier_custom_extensions(tmp_path):
node = FileTypeClassifier(supported_types=["my_extension"])
test_file = tmp_path / f"test.my_extension"
@ -53,12 +58,14 @@ def test_filetype_classifier_custom_extensions(tmp_path):
assert output == {"file_paths": [test_file]}
@pytest.mark.unit
def test_filetype_classifier_duplicate_custom_extensions():
with pytest.raises(ValueError):
FileTypeClassifier(supported_types=[f"my_extension", "my_extension"])
@pytest.mark.skipif(platform.system() == "Windows", reason="python-magic does not install properly on windows")
@pytest.mark.unit
@pytest.mark.skipif(platform.system() in ["Windows", "Darwin"], reason="python-magic not available")
def test_filetype_classifier_text_files_without_extension():
tested_types = ["docx", "html", "odt", "pdf", "pptx", "txt"]
node = FileTypeClassifier(supported_types=tested_types)
@ -70,7 +77,8 @@ def test_filetype_classifier_text_files_without_extension():
assert output == {"file_paths": [test_file]}
@pytest.mark.skipif(platform.system() == "Windows", reason="python-magic does not install properly on windows")
@pytest.mark.unit
@pytest.mark.skipif(platform.system() in ["Windows", "Darwin"], reason="python-magic not available")
def test_filetype_classifier_other_files_without_extension():
tested_types = ["gif", "jpg", "mp3", "png", "wav", "zip"]
node = FileTypeClassifier(supported_types=tested_types)
@ -82,8 +90,13 @@ def test_filetype_classifier_other_files_without_extension():
assert output == {"file_paths": [test_file]}
@pytest.mark.unit
def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch, caplog):
monkeypatch.delattr(haystack.nodes.file_classifier.file_type, "magic")
try:
monkeypatch.delattr(haystack.nodes.file_classifier.file_type, "magic")
except AttributeError:
# magic not installed, even better
pass
node = FileTypeClassifier(supported_types=[""])
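A side note on the magic guard added above: pytest's monkeypatch.delattr also accepts raising=False, which skips the deletion silently when the attribute is absent, so the same intent can be expressed without the try/except. A minimal sketch of that alternative (the module path is copied from the hunk above; the test name is illustrative):

import pytest
import haystack.nodes.file_classifier.file_type

@pytest.mark.unit
def test_without_magic(monkeypatch):
    # Does nothing if python-magic is not installed and the attribute is
    # therefore absent, instead of raising AttributeError.
    monkeypatch.delattr(haystack.nodes.file_classifier.file_type, "magic", raising=False)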

View File

@ -89,6 +89,7 @@ def patched_nltk_data_path(module_tmp_dir: Path, monkeypatch: MonkeyPatch, tmp_p
return tmp_path
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
def test_preprocess_sentence_split(split_length_and_results):
split_length, expected_documents_count = split_length_and_results
@ -101,6 +102,7 @@ def test_preprocess_sentence_split(split_length_and_results):
assert len(documents) == expected_documents_count
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
def test_preprocess_sentence_split_custom_models_wrong_file_format(split_length_and_results):
split_length, expected_documents_count = split_length_and_results
@ -118,6 +120,7 @@ def test_preprocess_sentence_split_custom_models_wrong_file_format(split_length_
assert len(documents) == expected_documents_count
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 15), (10, 2)])
def test_preprocess_sentence_split_custom_models_non_default_language(split_length_and_results):
split_length, expected_documents_count = split_length_and_results
@ -134,6 +137,7 @@ def test_preprocess_sentence_split_custom_models_non_default_language(split_leng
assert len(documents) == expected_documents_count
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 8), (8, 1)])
def test_preprocess_sentence_split_custom_models(split_length_and_results):
split_length, expected_documents_count = split_length_and_results
@ -151,6 +155,7 @@ def test_preprocess_sentence_split_custom_models(split_length_and_results):
assert len(documents) == expected_documents_count
@pytest.mark.unit
def test_preprocess_word_split():
document = Document(content=TEXT)
preprocessor = PreProcessor(
@ -178,6 +183,7 @@ def test_preprocess_word_split():
assert len(documents) == 15
@pytest.mark.unit
@pytest.mark.parametrize("split_length_and_results", [(1, 3), (2, 2)])
def test_preprocess_passage_split(split_length_and_results):
split_length, expected_documents_count = split_length_and_results
@ -206,6 +212,7 @@ def test_clean_header_footer():
assert "footer" not in documents[0].content
@pytest.mark.unit
def test_remove_substrings():
document = Document(content="This is a header. Some additional text. wiki. Some emoji ✨ 🪲 Weird whitespace\b\b\b.")
@ -226,6 +233,7 @@ def test_remove_substrings():
assert "" in documents[0].content
@pytest.mark.unit
def test_id_hash_keys_from_pipeline_params():
document_1 = Document(content="This is a document.", meta={"key": "a"})
document_2 = Document(content="This is a document.", meta={"key": "b"})
@ -242,6 +250,7 @@ def test_id_hash_keys_from_pipeline_params():
# test_input is a tuple consisting of the parameters for split_length, split_overlap and split_respect_sentence_boundary
# and the expected index in the output list of Documents where the page number changes from 1 to 2
@pytest.mark.unit
@pytest.mark.parametrize("test_input", [(10, 0, True, 5), (10, 0, False, 4), (10, 5, True, 6), (10, 5, False, 7)])
def test_page_number_extraction(test_input):
split_length, overlap, resp_sent_boundary, exp_doc_index = test_input
@ -261,6 +270,7 @@ def test_page_number_extraction(test_input):
assert doc.meta["page"] == 2
@pytest.mark.unit
def test_page_number_extraction_on_empty_pages():
"""
Often "marketing" documents contain pages without text (visuals only). When extracting page numbers, these pages should be counted as well to avoid
@ -283,6 +293,7 @@ def test_page_number_extraction_on_empty_pages():
assert documents[1].content.strip() == text_page_three
@pytest.mark.unit
def test_headline_processing_split_by_word():
expected_headlines = [
[{"headline": "sample sentence in paragraph_1", "start_idx": 11, "level": 0}],
@ -313,6 +324,7 @@ def test_headline_processing_split_by_word():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_headline_processing_split_by_word_overlap():
expected_headlines = [
[{"headline": "sample sentence in paragraph_1", "start_idx": 11, "level": 0}],
@ -347,6 +359,7 @@ def test_headline_processing_split_by_word_overlap():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_headline_processing_split_by_word_respect_sentence_boundary():
expected_headlines = [
[{"headline": "sample sentence in paragraph_1", "start_idx": 11, "level": 0}],
@ -378,6 +391,7 @@ def test_headline_processing_split_by_word_respect_sentence_boundary():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_headline_processing_split_by_sentence():
expected_headlines = [
[
@ -408,6 +422,7 @@ def test_headline_processing_split_by_sentence():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_headline_processing_split_by_sentence_overlap():
expected_headlines = [
[
@ -441,6 +456,7 @@ def test_headline_processing_split_by_sentence_overlap():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_headline_processing_split_by_passage():
expected_headlines = [
[
@ -471,6 +487,7 @@ def test_headline_processing_split_by_passage():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_headline_processing_split_by_passage_overlap():
expected_headlines = [
[
@ -499,6 +516,7 @@ def test_headline_processing_split_by_passage_overlap():
assert doc.meta["headlines"] == expected
@pytest.mark.unit
def test_file_exists_error_during_download(monkeypatch: MonkeyPatch, module_tmp_dir: Path):
# Pretend the model resources were not found in the first attempt
monkeypatch.setattr(nltk.data, "find", Mock(side_effect=[LookupError, str(module_tmp_dir)]))
@ -510,6 +528,7 @@ def test_file_exists_error_during_download(monkeypatch: MonkeyPatch, module_tmp_
PreProcessor(split_length=2, split_respect_sentence_boundary=False)
@pytest.mark.unit
def test_preprocessor_very_long_document(caplog):
preproc = PreProcessor(
clean_empty_lines=False, clean_header_footer=False, clean_whitespace=False, split_by=None, max_chars_check=10

View File

@ -15,6 +15,7 @@ def is_openai_api_key_set(api_key: str):
return len(api_key) > 0 and api_key != "KEY_NOT_FOUND"
@pytest.mark.unit
def test_prompt_templates():
p = PromptTemplate("t1", "Here is some fake template with variable $foo", ["foo"])
@ -53,6 +54,7 @@ def test_prompt_templates():
assert p.prompt_text == "Here is some fake template with variable $baz"
@pytest.mark.unit
def test_prompt_template_repr():
p = PromptTemplate("t", "Here is variable $baz")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable $baz, prompt_params=['baz'])"
@ -60,6 +62,7 @@ def test_prompt_template_repr():
assert str(p) == desired_repr
@pytest.mark.integration
def test_create_prompt_model():
model = PromptModel("google/flan-t5-small")
assert model.model_name_or_path == "google/flan-t5-small"
@ -99,6 +102,7 @@ def test_create_prompt_model_dtype():
assert model.model_name_or_path == "google/flan-t5-small"
@pytest.mark.integration
def test_create_prompt_node():
prompt_node = PromptNode()
assert prompt_node is not None
@ -122,6 +126,7 @@ def test_create_prompt_node():
PromptNode("some-random-model")
@pytest.mark.unit
def test_add_and_remove_template(prompt_node):
num_default_tasks = len(prompt_node.get_prompt_template_names())
custom_task = PromptTemplate(
@ -135,7 +140,8 @@ def test_add_and_remove_template(prompt_node):
assert "custom-task" not in prompt_node.get_prompt_template_names()
def test_invalid_template(prompt_node):
@pytest.mark.unit
def test_invalid_template():
with pytest.raises(ValueError, match="Invalid parameter"):
PromptTemplate(
name="custom-task", prompt_text="Custom task: $pram1 $param2", prompt_params=["param1", "param2"]
@ -145,6 +151,7 @@ def test_invalid_template(prompt_node):
PromptTemplate(name="custom-task", prompt_text="Custom task: $param1", prompt_params=["param1", "param2"])
@pytest.mark.integration
def test_add_template_and_invoke(prompt_node):
tt = PromptTemplate(
name="sentiment-analysis-new",
@ -158,6 +165,7 @@ def test_add_template_and_invoke(prompt_node):
assert r[0].casefold() == "positive"
@pytest.mark.integration
def test_on_the_fly_prompt(prompt_node):
tt = PromptTemplate(
name="sentiment-analysis-temp",
@ -169,6 +177,7 @@ def test_on_the_fly_prompt(prompt_node):
assert r[0].casefold() == "positive"
@pytest.mark.integration
def test_direct_prompting(prompt_node):
r = prompt_node("What is the capital of Germany?")
assert r[0].casefold() == "berlin"
@ -184,11 +193,13 @@ def test_direct_prompting(prompt_node):
assert len(r) == 2
@pytest.mark.integration
def test_question_generation(prompt_node):
r = prompt_node.prompt("question-generation", documents=["Berlin is the capital of Germany."])
assert len(r) == 1 and len(r[0]) > 0
@pytest.mark.integration
def test_template_selection(prompt_node):
qa = prompt_node.set_default_prompt_template("question-answering")
r = qa(
@ -198,26 +209,31 @@ def test_template_selection(prompt_node):
assert r[0].casefold() == "berlin" and r[1].casefold() == "paris"
@pytest.mark.integration
def test_has_supported_template_names(prompt_node):
assert len(prompt_node.get_prompt_template_names()) > 0
@pytest.mark.integration
def test_invalid_template_params(prompt_node):
with pytest.raises(ValueError, match="Expected prompt parameters"):
prompt_node.prompt("question-answering", {"some_crazy_key": "Berlin is the capital of Germany."})
@pytest.mark.integration
def test_wrong_template_params(prompt_node):
with pytest.raises(ValueError, match="Expected prompt parameters"):
# we don't have the options param here; multiple-choice QA does
prompt_node.prompt("question-answering", options=["Berlin is the capital of Germany."])
@pytest.mark.integration
def test_run_invalid_template(prompt_node):
with pytest.raises(ValueError, match="invalid-task not supported"):
prompt_node.prompt("invalid-task", {})
@pytest.mark.integration
def test_invalid_prompting(prompt_node):
with pytest.raises(ValueError, match="Hey there, what is the best city in the worl"):
prompt_node.prompt(
@ -228,6 +244,7 @@ def test_invalid_prompting(prompt_node):
prompt_node.prompt(["Hey there, what is the best city in the world?", "Hey, answer me!"])
@pytest.mark.integration
def test_invalid_state_ops(prompt_node):
with pytest.raises(ValueError, match="Prompt template no_such_task_exists"):
prompt_node.remove_prompt_template("no_such_task_exists")
@ -405,6 +422,7 @@ def test_complex_pipeline_with_qa(prompt_model):
)
@pytest.mark.integration
def test_complex_pipeline_with_shared_model():
model = PromptModel()
node = PromptNode(
@ -420,6 +438,7 @@ def test_complex_pipeline_with_shared_model():
assert result["results"][0] == "Berlin"
@pytest.mark.integration
def test_simple_pipeline_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -443,6 +462,7 @@ def test_simple_pipeline_yaml(tmp_path):
assert result["results"][0] == "positive"
@pytest.mark.integration
def test_simple_pipeline_yaml_with_default_params(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -470,6 +490,7 @@ def test_simple_pipeline_yaml_with_default_params(tmp_path):
assert result["results"][0] == "positive"
@pytest.mark.integration
def test_complex_pipeline_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -505,6 +526,7 @@ def test_complex_pipeline_yaml(tmp_path):
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -544,6 +566,7 @@ def test_complex_pipeline_with_shared_prompt_model_yaml(tmp_path):
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
@pytest.mark.integration
def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_path):
with open(tmp_path / "tmp_config_with_prompt_template.yml", "w") as tmp_file:
tmp_file.write(
@ -592,6 +615,7 @@ def test_complex_pipeline_with_shared_prompt_model_and_prompt_template_yaml(tmp_
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
@pytest.mark.integration
def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_path):
# test that we can stick some random node in between prompt nodes and that everything still works
# most specifically, we want to ensure that invocation_context is still populated correctly and propagated
@ -672,6 +696,7 @@ def test_complex_pipeline_with_with_dummy_node_between_prompt_nodes_yaml(tmp_pat
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
@ -733,6 +758,7 @@ def test_complex_pipeline_with_all_features(tmp_path):
assert "questions" in result["invocation_context"] and len(result["invocation_context"]["questions"]) > 0
@pytest.mark.integration
def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_path):
# p2 and p3 are essentially the same PromptNode component, make sure we can use them both as is in the pipeline
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
@ -843,6 +869,7 @@ class TestRunBatch:
# TODO Finish
@pytest.mark.integration
def test_HFLocalInvocationLayer_supports():
assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
assert HFLocalInvocationLayer.supports("bigscience/T0_3B")

View File

@ -22,12 +22,14 @@ def mock_function_two_outputs(monkeypatch):
)
@pytest.mark.unit
def test_basic_invocation_only_inputs(mock_function):
shaper = Shaper(func="test_function", inputs={"a": "query", "b": "documents"}, outputs=["c"])
results, _ = shaper.run(query="test query", documents=["doesn't", "really", "matter"])
assert results["invocation_context"]["c"] == ["test query", "test query", "test query"]
@pytest.mark.unit
def test_multiple_outputs(mock_function_two_outputs):
shaper = Shaper(func="two_output_test_function", inputs={"a": "query"}, outputs=["c", "d"])
results, _ = shaper.run(query="test")
@ -35,6 +37,7 @@ def test_multiple_outputs(mock_function_two_outputs):
assert results["invocation_context"]["d"] == 4
@pytest.mark.unit
def test_multiple_outputs_error(mock_function_two_outputs, caplog):
shaper = Shaper(func="two_output_test_function", inputs={"a": "query"}, outputs=["c"])
with caplog.at_level(logging.WARNING):
@ -42,18 +45,21 @@ def test_multiple_outputs_error(mock_function_two_outputs, caplog):
assert "Only 1 output(s) will be stored." in caplog.text
@pytest.mark.unit
def test_basic_invocation_only_params(mock_function):
shaper = Shaper(func="test_function", params={"a": "A", "b": list(range(3))}, outputs=["c"])
results, _ = shaper.run()
assert results["invocation_context"]["c"] == ["A", "A", "A"]
@pytest.mark.unit
def test_basic_invocation_inputs_and_params(mock_function):
shaper = Shaper(func="test_function", inputs={"a": "query"}, params={"b": list(range(2))}, outputs=["c"])
results, _ = shaper.run(query="test query")
assert results["invocation_context"]["c"] == ["test query", "test query"]
@pytest.mark.unit
def test_basic_invocation_inputs_and_params_colliding(mock_function):
shaper = Shaper(
func="test_function", inputs={"a": "query"}, params={"a": "default value", "b": list(range(2))}, outputs=["c"]
@ -62,6 +68,7 @@ def test_basic_invocation_inputs_and_params_colliding(mock_function):
assert results["invocation_context"]["c"] == ["test query", "test query"]
@pytest.mark.unit
def test_basic_invocation_inputs_and_params_using_params_as_defaults(mock_function):
shaper = Shaper(
func="test_function", inputs={"a": "query"}, params={"a": "default", "b": list(range(2))}, outputs=["c"]
@ -70,12 +77,14 @@ def test_basic_invocation_inputs_and_params_using_params_as_defaults(mock_functi
assert results["invocation_context"]["c"] == ["default", "default"]
@pytest.mark.unit
def test_missing_argument(mock_function):
shaper = Shaper(func="test_function", inputs={"b": "documents"}, outputs=["c"])
with pytest.raises(ValueError, match="Shaper couldn't apply the function to your inputs and parameters."):
shaper.run(query="test query", documents=["doesn't", "really", "matter"])
@pytest.mark.unit
def test_excess_argument(mock_function):
shaper = Shaper(
func="test_function", inputs={"a": "query", "b": "documents", "something_extra": "query"}, outputs=["c"]
@ -84,12 +93,14 @@ def test_excess_argument(mock_function):
shaper.run(query="test query", documents=["doesn't", "really", "matter"])
@pytest.mark.unit
def test_value_not_in_invocation_context(mock_function):
shaper = Shaper(func="test_function", inputs={"a": "query", "b": "something_that_does_not_exist"}, outputs=["c"])
with pytest.raises(ValueError, match="Shaper couldn't apply the function to your inputs and parameters."):
shaper.run(query="test query", documents=["doesn't", "really", "matter"])
@pytest.mark.unit
def test_value_only_in_invocation_context(mock_function):
shaper = Shaper(func="test_function", inputs={"a": "query", "b": "invocation_context_specific"}, outputs=["c"])
results, _s = shaper.run(
@ -98,6 +109,7 @@ def test_value_only_in_invocation_context(mock_function):
assert results["invocation_context"]["c"] == ["test query", "test query", "test query"]
@pytest.mark.unit
def test_yaml(mock_function, tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -136,12 +148,14 @@ def test_yaml(mock_function, tmp_path):
#
@pytest.mark.unit
def test_rename():
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["questions"])
results, _ = shaper.run(query="test query")
assert results["invocation_context"]["questions"] == "test query"
@pytest.mark.unit
def test_rename_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -175,12 +189,14 @@ def test_rename_yaml(tmp_path):
#
@pytest.mark.unit
def test_value_to_list():
shaper = Shaper(func="value_to_list", inputs={"value": "query", "target_list": "documents"}, outputs=["questions"])
results, _ = shaper.run(query="test query", documents=["doesn't", "really", "matter"])
assert results["invocation_context"]["questions"] == ["test query", "test query", "test query"]
@pytest.mark.unit
def test_value_to_list_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -219,12 +235,14 @@ def test_value_to_list_yaml(tmp_path):
#
@pytest.mark.unit
def test_join_lists():
shaper = Shaper(func="join_lists", params={"lists": [[1, 2, 3], [4, 5]]}, outputs=["list"])
results, _ = shaper.run()
assert results["invocation_context"]["list"] == [1, 2, 3, 4, 5]
@pytest.mark.unit
def test_join_lists_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -259,6 +277,7 @@ def test_join_lists_yaml(tmp_path):
#
@pytest.mark.unit
def test_join_strings():
shaper = Shaper(
func="join_strings", params={"strings": ["first", "second"], "delimiter": " | "}, outputs=["single_string"]
@ -267,12 +286,14 @@ def test_join_strings():
assert results["invocation_context"]["single_string"] == "first | second"
@pytest.mark.unit
def test_join_strings_default_delimiter():
shaper = Shaper(func="join_strings", params={"strings": ["first", "second"]}, outputs=["single_string"])
results, _ = shaper.run()
assert results["invocation_context"]["single_string"] == "first second"
@pytest.mark.unit
def test_join_strings_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -302,6 +323,7 @@ def test_join_strings_yaml(tmp_path):
assert result["invocation_context"]["single_string"] == "first - second - third"
@pytest.mark.unit
def test_join_strings_default_delimiter_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -334,6 +356,7 @@ def test_join_strings_default_delimiter_yaml(tmp_path):
#
@pytest.mark.unit
def test_join_documents():
shaper = Shaper(
func="join_documents", inputs={"documents": "documents"}, params={"delimiter": " | "}, outputs=["documents"]
@ -375,6 +398,7 @@ def test_join_documents_with_publish_outputs_as_list():
assert results["documents"] == [Document(content="first | second | third")]
@pytest.mark.unit
def test_join_documents_default_delimiter():
shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
results, _ = shaper.run(
@ -383,6 +407,7 @@ def test_join_documents_default_delimiter():
assert results["invocation_context"]["documents"] == [Document(content="first second third")]
@pytest.mark.unit
def test_join_documents_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -417,6 +442,7 @@ def test_join_documents_yaml(tmp_path):
assert result["documents"] == [Document(content="first - second - third")]
@pytest.mark.unit
def test_join_documents_default_delimiter_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -451,6 +477,7 @@ def test_join_documents_default_delimiter_yaml(tmp_path):
#
@pytest.mark.unit
def test_strings_to_answers_no_meta_no_hashkeys():
shaper = Shaper(func="strings_to_answers", inputs={"strings": "responses"}, outputs=["answers"])
results, _ = shaper.run(invocation_context={"responses": ["first", "second", "third"]})
@ -461,6 +488,7 @@ def test_strings_to_answers_no_meta_no_hashkeys():
]
@pytest.mark.unit
def test_strings_to_answers_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -502,12 +530,14 @@ def test_strings_to_answers_yaml(tmp_path):
#
@pytest.mark.unit
def test_answers_to_strings():
shaper = Shaper(func="answers_to_strings", inputs={"answers": "documents"}, outputs=["strings"])
results, _ = shaper.run(documents=[Answer(answer="first"), Answer(answer="second"), Answer(answer="third")])
assert results["invocation_context"]["strings"] == ["first", "second", "third"]
@pytest.mark.unit
def test_answers_to_strings_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -540,6 +570,7 @@ def test_answers_to_strings_yaml(tmp_path):
#
@pytest.mark.unit
def test_strings_to_documents_no_meta_no_hashkeys():
shaper = Shaper(func="strings_to_documents", inputs={"strings": "responses"}, outputs=["documents"])
results, _ = shaper.run(invocation_context={"responses": ["first", "second", "third"]})
@ -550,6 +581,7 @@ def test_strings_to_documents_no_meta_no_hashkeys():
]
@pytest.mark.unit
def test_strings_to_documents_single_meta_no_hashkeys():
shaper = Shaper(
func="strings_to_documents", inputs={"strings": "responses"}, params={"meta": {"a": "A"}}, outputs=["documents"]
@ -562,6 +594,7 @@ def test_strings_to_documents_single_meta_no_hashkeys():
]
@pytest.mark.unit
def test_strings_to_documents_wrong_number_of_meta():
shaper = Shaper(
func="strings_to_documents",
@ -574,6 +607,7 @@ def test_strings_to_documents_wrong_number_of_meta():
shaper.run(invocation_context={"responses": ["first", "second", "third"]})
@pytest.mark.unit
def test_strings_to_documents_many_meta_no_hashkeys():
shaper = Shaper(
func="strings_to_documents",
@ -589,6 +623,7 @@ def test_strings_to_documents_many_meta_no_hashkeys():
]
@pytest.mark.unit
def test_strings_to_documents_single_meta_with_hashkeys():
shaper = Shaper(
func="strings_to_documents",
@ -604,6 +639,7 @@ def test_strings_to_documents_single_meta_with_hashkeys():
]
@pytest.mark.unit
def test_strings_to_documents_no_meta_no_hashkeys_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -635,6 +671,7 @@ def test_strings_to_documents_no_meta_no_hashkeys_yaml(tmp_path):
]
@pytest.mark.unit
def test_strings_to_documents_meta_and_hashkeys_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -676,6 +713,7 @@ def test_strings_to_documents_meta_and_hashkeys_yaml(tmp_path):
#
@pytest.mark.unit
def test_documents_to_strings():
shaper = Shaper(func="documents_to_strings", inputs={"documents": "documents"}, outputs=["strings"])
results, _ = shaper.run(
@ -684,6 +722,7 @@ def test_documents_to_strings():
assert results["invocation_context"]["strings"] == ["first", "second", "third"]
@pytest.mark.unit
def test_documents_to_strings_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -716,6 +755,7 @@ def test_documents_to_strings_yaml(tmp_path):
#
@pytest.mark.unit
def test_chain_shapers():
shaper_1 = Shaper(
func="join_documents", inputs={"documents": "documents"}, params={"delimiter": " - "}, outputs=["documents"]
@ -736,6 +776,7 @@ def test_chain_shapers():
assert results["invocation_context"]["questions"] == ["test query"]
@pytest.mark.unit
def test_chain_shapers_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -785,6 +826,7 @@ def test_chain_shapers_yaml(tmp_path):
assert results["invocation_context"]["questions"] == ["test query"]
@pytest.mark.unit
def test_chain_shapers_yaml_2(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -871,6 +913,7 @@ def test_chain_shapers_yaml_2(tmp_path):
assert results["invocation_context"]["documents_with_greetings"] == [Document(content="hello. hello. hello")]
@pytest.mark.integration
def test_with_prompt_node(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -920,6 +963,7 @@ def test_with_prompt_node(tmp_path):
assert len(result["invocation_context"]["questions"]) == 2
@pytest.mark.integration
def test_with_multiple_prompt_nodes(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -997,6 +1041,7 @@ def test_with_multiple_prompt_nodes(tmp_path):
assert any([True for r in results if "Berlin" in r])
@pytest.mark.unit
def test_join_query_and_documents_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -1042,6 +1087,7 @@ def test_join_query_and_documents_yaml(tmp_path):
assert result["query"] == ["first", "second", "third", "What is going on here?"]
@pytest.mark.unit
def test_join_query_and_documents_into_single_string_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -1098,6 +1144,7 @@ def test_join_query_and_documents_into_single_string_yaml(tmp_path):
assert result["query"] == "first second third What is going on here?"
@pytest.mark.unit
def test_join_query_and_documents_convert_into_documents_yaml(tmp_path):
with open(tmp_path / "tmp_config.yml", "w") as tmp_file:
tmp_file.write(
@ -1156,6 +1203,7 @@ def test_join_query_and_documents_convert_into_documents_yaml(tmp_path):
assert isinstance(result["invocation_context"]["query_and_docs"][0], Document)
@pytest.mark.unit
def test_shaper_publishes_unknown_arg_does_not_break_pipeline():
documents = [Document(content="test query")]
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["unknown_by_retriever"], publish_outputs=True)