diff --git a/haystack/components/converters/openapi_functions.py b/haystack/components/converters/openapi_functions.py index b35386003..8315807d8 100644 --- a/haystack/components/converters/openapi_functions.py +++ b/haystack/components/converters/openapi_functions.py @@ -48,16 +48,18 @@ class OpenAPIServiceToFunctions: openapi_imports.check() @component.output_types(documents=List[Document]) - def run(self, sources: List[Union[str, Path, ByteStream]], system_messages: List[str]) -> Dict[str, Any]: + def run( + self, sources: List[Union[str, Path, ByteStream]], system_messages: Optional[List[str]] = None + ) -> Dict[str, Any]: """ Processes OpenAPI specification URLs or files to extract functions that can be invoked via OpenAI function - calling mechanism. Each source is paired with a system message in one-to-one correspondence. The system message - is used to assist LLM in the response generation. + calling mechanism. Each source is paired with an optional system message. The system message can be potentially + used in LLM response generation. :param sources: A list of OpenAPI specification sources, which can be URLs, file paths, or ByteStream objects. :type sources: List[Union[str, Path, ByteStream]] - :param system_messages: A list of system messages corresponding to each source. - :type system_messages: List[str] + :param system_messages: A list of optional system messages corresponding to each source. + :type system_messages: Optional[List[str]] :return: A dictionary with a key 'documents' containing a list of Document objects. Each Document object encapsulates a function definition and relevant metadata. :rtype: Dict[str, Any] @@ -65,6 +67,7 @@ class OpenAPIServiceToFunctions: :raises ValueError: If the source type is not recognized or no functions are found in the OpenAPI specification. """ documents: List[Document] = [] + system_messages = system_messages or [""] * len(sources) for source, system_message in zip(sources, system_messages): openapi_spec_content = None if isinstance(source, (str, Path)): @@ -83,14 +86,12 @@ class OpenAPIServiceToFunctions: try: service_openapi_spec = self._parse_openapi_spec(openapi_spec_content) functions: List[Dict[str, Any]] = self._openapi_to_functions(service_openapi_spec) - docs = [ - Document( - content=json.dumps(function), - meta={"spec": service_openapi_spec, "system_message": system_message}, - ) - for function in functions - ] - documents.extend(docs) + for function in functions: + meta: Dict[str, Any] = {"spec": service_openapi_spec} + if system_message: + meta["system_message"] = system_message + doc = Document(content=json.dumps(function), meta=meta) + documents.append(doc) except Exception as e: logger.error("Error processing OpenAPI specification from source %s: %s", source, e) diff --git a/test/components/converters/test_openapi_functions.py b/test/components/converters/test_openapi_functions.py index 79e202ee6..db704ae6a 100644 --- a/test/components/converters/test_openapi_functions.py +++ b/test/components/converters/test_openapi_functions.py @@ -219,3 +219,23 @@ class TestOpenAPIServiceToFunctions: # check that the metadata is as expected assert doc.meta["system_message"] == "Some system message we don't care about here" assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec) + + def test_run_with_file_source_and_none_system_messages(self, json_serperdev_openapi_spec): + service = OpenAPIServiceToFunctions() + spec_stream = ByteStream.from_string(json_serperdev_openapi_spec) + + # we now omit the system_messages argument + result = service.run(sources=[spec_stream]) + assert len(result["documents"]) == 1 + doc = result["documents"][0] + + # check that the content is as expected + assert ( + doc.content + == '{"name": "search", "description": "Search the web with Google", "parameters": {"type": "object", ' + '"properties": {"requestBody": {"type": "object", "properties": {"q": {"type": "string"}}}}}}' + ) + + # check that the metadata is as expected, system_message should not be present + assert "system_message" not in doc.meta + assert doc.meta["spec"] == json.loads(json_serperdev_openapi_spec)