mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-17 10:58:51 +00:00)
move Cohere generator into dedicated integration (#6475)
This commit is contained in:
parent 09f898aff8
commit a86807b834
.github/workflows/linting.yml (vendored, 4 changed lines)
@@ -38,7 +38,7 @@ jobs:
           python-version: ${{ env.PYTHON_VERSION }}
 
       - name: Install Haystack
-        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' cohere
+        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2'
 
       - name: Mypy
         if: steps.files.outputs.any_changed == 'true'
@@ -69,7 +69,7 @@ jobs:
 
       - name: Install Haystack
         run: |
-          pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere
+          pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
 
       - name: Pylint
         if: steps.files.outputs.any_changed == 'true'
.github/workflows/tests.yml (vendored, 9 changed lines)
@@ -22,7 +22,6 @@ on:
 
 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-  COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
   CORE_AZURE_CS_ENDPOINT: ${{ secrets.CORE_AZURE_CS_ENDPOINT }}
   CORE_AZURE_CS_API_KEY: ${{ secrets.CORE_AZURE_CS_API_KEY }}
   PYTHON_VERSION: "3.8"
@@ -99,7 +98,7 @@ jobs:
           python-version: ${{ env.PYTHON_VERSION }}
 
       - name: Install Haystack
-        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere boilerpy3
+        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' boilerpy3
 
       - name: Run
         run: pytest -m "not integration" test
@@ -157,7 +156,7 @@ jobs:
           sudo apt install ffmpeg # for local Whisper tests
 
       - name: Install Haystack
-        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere boilerpy3
+        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' boilerpy3
 
       - name: Run
         run: pytest --maxfail=5 -m "integration" test
@@ -213,7 +212,7 @@ jobs:
           colima start
 
       - name: Install Haystack
-        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere boilerpy3
+        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' boilerpy3
 
       - name: Run Tika
         run: docker run -d -p 9998:9998 apache/tika:2.9.0.0
@@ -264,7 +263,7 @@ jobs:
           python-version: ${{ env.PYTHON_VERSION }}
 
       - name: Install Haystack
-        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere boilerpy3
+        run: pip install .[dev,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' boilerpy3
 
       - name: Run
         run: pytest --maxfail=5 -m "integration" test -k 'not tika'
@@ -1,7 +1,7 @@
 loaders:
   - type: loaders.CustomPythonLoader
     search_path: [../../../haystack/preview/components/generators]
-    modules: ["hugging_face_local", "hugging_face_tgi", "openai", "cohere", "chat/hugging_face_tgi", "chat/openai"]
+    modules: ["hugging_face_local", "hugging_face_tgi", "openai", "chat/hugging_face_tgi", "chat/openai"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
@@ -1,6 +1,5 @@
-from haystack.components.generators.cohere import CohereGenerator
 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
 from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
 from haystack.components.generators.openai import GPTGenerator
 
-__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator", "CohereGenerator"]
+__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]
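With this move, the import removed above no longer resolves from Haystack core. A minimal migration sketch follows; the dedicated integration's package name and import path are assumptions for illustration (check the Cohere integration's own README for the real names), while the fallback branch uses the exact path deleted in this commit:

```python
# Hedged import shim: prefer the dedicated Cohere integration when installed,
# fall back to the pre-#6475 core location on older Haystack versions.
try:
    # Assumed new home; `pip install cohere-haystack` is a guess at the package name.
    from cohere_haystack.generator import CohereGenerator
except ImportError:
    # Path removed by this commit; works only on Haystack versions before #6475.
    from haystack.components.generators.cohere import CohereGenerator

generator = CohereGenerator(api_key="test-api-key")
```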
@@ -1,159 +0,0 @@ (entire file deleted)
import logging
import os
import sys
from typing import Any, Callable, Dict, List, Optional

from haystack.lazy_imports import LazyImport
from haystack import DeserializationError, component, default_from_dict, default_to_dict

with LazyImport(message="Run 'pip install cohere'") as cohere_import:
    from cohere import Client, COHERE_API_URL

logger = logging.getLogger(__name__)


@component
class CohereGenerator:
    """LLM Generator compatible with Cohere's generate endpoint.

    Queries the LLM using Cohere's API. Invocations are made using the 'cohere' package.
    See [Cohere API](https://docs.cohere.com/reference/generate) for more details.

    Example usage:

    ```python
    from haystack.components.generators import CohereGenerator
    generator = CohereGenerator(api_key="test-api-key")
    generator.run(prompt="What's the capital of France?")
    ```
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "command",
        streaming_callback: Optional[Callable] = None,
        api_base_url: Optional[str] = None,
        **kwargs,
    ):
        """
        Instantiates a `CohereGenerator` component.
        :param api_key: The API key for the Cohere API. If not set, it is read from the COHERE_API_KEY env var.
        :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, command-light-nightly]. Defaults to "command".
        :param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
        :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
        :param kwargs: Additional model parameters. These are used during generation. Refer to https://docs.cohere.com/reference/generate for more details.
            Some of the parameters are:
            - 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024.
            - 'truncate': One of NONE|START|END to specify how the API handles inputs longer than the maximum token length. Defaults to END.
            - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations.
            - 'preset': Identifier of a custom preset. A preset is a combination of parameters, such as prompt and temperature. You can create presets in the playground.
            - 'end_sequences': The generated text is cut at the beginning of the earliest occurrence of an end sequence. The sequence is excluded from the text.
            - 'stop_sequences': The generated text is cut at the end of the earliest occurrence of a stop sequence. The sequence is included in the text.
            - 'k': Ensures that only the top `k` most likely tokens are considered for generation at each step. Defaults to 0 (disabled).
            - 'p': Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
            - 'frequency_penalty': Used to reduce repetitiveness of generated tokens. The higher the value, the stronger the penalty applied to previously present tokens,
              proportional to how many times they have already appeared in the prompt or prior generation.
            - 'presence_penalty': Defaults to 0.0, min value of 0.0, max value of 1.0. Can be used to reduce repetitiveness of generated tokens.
              Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
            - 'return_likelihoods': One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with the response. Defaults to NONE.
            - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens.
              The format is {token_id: bias} where bias is a float between -10 and 10.
        """
        cohere_import.check()

        if not api_key:
            api_key = os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "CohereGenerator needs an API key to run. Either provide it as an init parameter or set the env var COHERE_API_KEY."
            )

        if not api_base_url:
            api_base_url = COHERE_API_URL

        self.api_key = api_key
        self.model_name = model_name
        self.streaming_callback = streaming_callback
        self.api_base_url = api_base_url
        self.model_parameters = kwargs
        self.client = Client(api_key=self.api_key, api_url=self.api_base_url)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        if self.streaming_callback:
            module = self.streaming_callback.__module__
            if module == "builtins":
                callback_name = self.streaming_callback.__name__
            else:
                callback_name = f"{module}.{self.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            model_name=self.model_name,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            **self.model_parameters,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        if "streaming_callback" in init_params and init_params["streaming_callback"]:
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str):
        """
        Queries the LLM with the prompt to produce replies.
        :param prompt: The prompt to be sent to the generative model.
        """
        response = self.client.generate(
            model=self.model_name, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
        )
        if self.streaming_callback:
            metadata_dict: Dict[str, Any] = {}
            for chunk in response:
                self.streaming_callback(chunk)
                metadata_dict["index"] = chunk.index
            replies = response.texts
            metadata_dict["finish_reason"] = response.finish_reason
            metadata = [metadata_dict]
            self._check_truncated_answers(metadata)
            return {"replies": replies, "metadata": metadata}

        metadata = [{"finish_reason": resp.finish_reason} for resp in response]
        replies = [resp.text for resp in response]
        self._check_truncated_answers(metadata)
        return {"replies": replies, "metadata": metadata}

    def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
        """
        Check the `finish_reason` returned with the Cohere response.
        If the `finish_reason` is `MAX_TOKENS`, log a warning to the user.
        :param metadata: The metadata returned by the Cohere API.
        """
        if metadata[0]["finish_reason"] == "MAX_TOKENS":
            logger.warning(
                "Responses have been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions."
            )
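Before its removal, the serialization pair above could round-trip a component, including a module-level streaming callback stored by its dotted path. A minimal sketch, assuming a valid `COHERE_API_KEY` in the environment (`to_dict` deliberately omits the API key, so `from_dict` falls back to the env var) and a hypothetical `print_chunk` helper:

```python
# Round-trip sketch for the to_dict/from_dict logic above (pre-#6475 code).
from haystack.components.generators.cohere import CohereGenerator

def print_chunk(chunk):
    # Hypothetical module-level callback; it must be resolvable by dotted
    # path via sys.modules, which is why a lambda serializes as "<lambda>"
    # and cannot be restored.
    print(chunk.text, flush=True, end="")

generator = CohereGenerator(streaming_callback=print_chunk, max_tokens=256)
data = generator.to_dict()
# When run as a script, the callback is stored as "__main__.print_chunk".
restored = CohereGenerator.from_dict(data)  # re-resolves the callback via sys.modules
assert restored.streaming_callback is print_chunk
```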
@@ -1,178 +0,0 @@ (entire file deleted)
import os

import pytest

from haystack.components.generators import CohereGenerator


def default_streaming_callback(chunk):
    """
    Default callback function for streaming responses from the Cohere API.
    Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
    """
    print(chunk.text, flush=True, end="")

@pytest.mark.integration
class TestCohereGenerator:
    def test_init_default(self):
        import cohere

        component = CohereGenerator(api_key="test-api-key")
        assert component.api_key == "test-api-key"
        assert component.model_name == "command"
        assert component.streaming_callback is None
        assert component.api_base_url == cohere.COHERE_API_URL
        assert component.model_parameters == {}

    def test_init_with_parameters(self):
        callback = lambda x: x
        component = CohereGenerator(
            api_key="test-api-key",
            model_name="command-light",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=callback,
            api_base_url="test-base-url",
        )
        assert component.api_key == "test-api-key"
        assert component.model_name == "command-light"
        assert component.streaming_callback == callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    def test_to_dict_default(self):
        import cohere

        component = CohereGenerator(api_key="test-api-key")
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.cohere.CohereGenerator",
            "init_parameters": {
                "model_name": "command",
                "streaming_callback": None,
                "api_base_url": cohere.COHERE_API_URL,
            },
        }

    def test_to_dict_with_parameters(self):
        component = CohereGenerator(
            api_key="test-api-key",
            model_name="command-light",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.cohere.CohereGenerator",
            "init_parameters": {
                "model_name": "command-light",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "test_cohere_generators.default_streaming_callback",
            },
        }

    def test_to_dict_with_lambda_streaming_callback(self):
        component = CohereGenerator(
            api_key="test-api-key",
            model_name="command",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=lambda x: x,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.cohere.CohereGenerator",
            "init_parameters": {
                "model_name": "command",
                "streaming_callback": "test_cohere_generators.<lambda>",
                "api_base_url": "test-base-url",
                "max_tokens": 10,
                "some_test_param": "test-params",
            },
        }

    def test_from_dict(self, monkeypatch):
        monkeypatch.setenv("COHERE_API_KEY", "test-key")
        data = {
            "type": "haystack.components.generators.cohere.CohereGenerator",
            "init_parameters": {
                "model_name": "command",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "test_cohere_generators.default_streaming_callback",
            },
        }
        component = CohereGenerator.from_dict(data)
        assert component.api_key == "test-key"
        assert component.model_name == "command"
        assert component.streaming_callback == default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    def test_check_truncated_answers(self, caplog):
        component = CohereGenerator(api_key="test-api-key")
        metadata = [{"finish_reason": "MAX_TOKENS"}]
        component._check_truncated_answers(metadata)
        assert caplog.records[0].message == (
            "Responses have been truncated before reaching a natural stopping point. "
            "Increase the max_tokens parameter to allow for longer completions."
        )

    @pytest.mark.skipif(
        not os.environ.get("COHERE_API_KEY", None),
        reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run(self):
        component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY"))
        results = component.run(prompt="What's the capital of France?")
        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert results["metadata"][0]["finish_reason"] == "COMPLETE"

    @pytest.mark.skipif(
        not os.environ.get("COHERE_API_KEY", None),
        reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run_wrong_model_name(self):
        import cohere

        component = CohereGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
        with pytest.raises(
            cohere.CohereAPIError,
            match="model not found, make sure the correct model ID was used and that you have access to the model.",
        ):
            component.run(prompt="What's the capital of France?")

    @pytest.mark.skipif(
        not os.environ.get("COHERE_API_KEY", None),
        reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
    )
    @pytest.mark.integration
    def test_cohere_generator_run_streaming(self):
        class Callback:
            def __init__(self):
                self.responses = ""

            def __call__(self, chunk):
                self.responses += chunk.text
                return chunk

        callback = Callback()
        component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback)
        results = component.run(prompt="What's the capital of France?")

        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert results["metadata"][0]["finish_reason"] == "COMPLETE"
        assert callback.responses == results["replies"][0]
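For reference, CI invoked these tests with plain `pytest -m "integration"` runs, as in the workflow diffs above. A hedged local runner for just the Cohere tests, assuming the repository's `test` directory layout from those workflows:

```python
# Run only the Cohere-related integration tests (pre-removal layout assumed).
import os
import sys

import pytest

# The integration tests skip themselves without this variable.
assert os.environ.get("COHERE_API_KEY"), "export COHERE_API_KEY first"
sys.exit(pytest.main(["-m", "integration", "-k", "cohere", "test"]))
```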