mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-30 17:29:29 +00:00 
			
		
		
		
	Fix for custom template in OpenAIAnswerGenerator (#4220)
This commit is contained in:
		
							parent
							
								
									c4b98fcccc
								
							
						
					
					
						commit
						2bedb80ba5
					
				| @ -138,7 +138,7 @@ class OpenAIAnswerGenerator(BaseGenerator): | |||||||
|         else: |         else: | ||||||
|             # Check for required prompts |             # Check for required prompts | ||||||
|             required_params = ["context", "query"] |             required_params = ["context", "query"] | ||||||
|             if all(p in prompt_template.prompt_params for p in required_params): |             if not all(p in prompt_template.prompt_params for p in required_params): | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     "The OpenAIAnswerGenerator requires a PromptTemplate that has `context` and " |                     "The OpenAIAnswerGenerator requires a PromptTemplate that has `context` and " | ||||||
|                     "`query` in its `prompt_params`. Supply a different `prompt_template` or " |                     "`query` in its `prompt_params`. Supply a different `prompt_template` or " | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ import pytest | |||||||
| from haystack.schema import Document | from haystack.schema import Document | ||||||
| from haystack.nodes.answer_generator import Seq2SeqGenerator, OpenAIAnswerGenerator | from haystack.nodes.answer_generator import Seq2SeqGenerator, OpenAIAnswerGenerator | ||||||
| from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline | from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline | ||||||
|  | from haystack.nodes import PromptTemplate | ||||||
| 
 | 
 | ||||||
| import logging | import logging | ||||||
| 
 | 
 | ||||||
| @ -142,6 +143,26 @@ def test_openai_answer_generator(openai_generator, docs): | |||||||
|     assert "Carla" in prediction["answers"][0].answer |     assert "Carla" in prediction["answers"][0].answer | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @pytest.mark.integration | ||||||
|  | @pytest.mark.skipif( | ||||||
|  |     not os.environ.get("OPENAI_API_KEY", None), | ||||||
|  |     reason="No OpenAI API key provided. Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", | ||||||
|  | ) | ||||||
|  | def test_openai_answer_generator_custom_template(docs): | ||||||
|  |     lfqa_prompt = PromptTemplate( | ||||||
|  |         name="lfqa", | ||||||
|  |         prompt_text=""" | ||||||
|  |         Synthesize a comprehensive answer from your knowledge and the following topk most relevant paragraphs and the given question. | ||||||
|  |         \n===\Paragraphs: $context\n===\n$query""", | ||||||
|  |         prompt_params=["context", "query"], | ||||||
|  |     ) | ||||||
|  |     node = OpenAIAnswerGenerator( | ||||||
|  |         api_key=os.environ.get("OPENAI_API_KEY", ""), model="text-babbage-001", top_k=1, prompt_template=lfqa_prompt | ||||||
|  |     ) | ||||||
|  |     prediction = node.predict(query="Who lives in Berlin?", documents=docs, top_k=1) | ||||||
|  |     assert len(prediction["answers"]) == 1 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @pytest.mark.integration | @pytest.mark.integration | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||||
|     not os.environ.get("OPENAI_API_KEY", None), |     not os.environ.get("OPENAI_API_KEY", None), | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian
						Sebastian