| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  | import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-28 13:52:21 +02:00
										 |  |  | from haystack import Pipeline | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  | from haystack.nodes.audio import WhisperTranscriber | 
					
						
							|  |  |  | from haystack.utils.import_utils import is_whisper_available | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found") | 
					
						
							|  |  |  | @pytest.mark.integration | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_whisper_api_transcribe(samples_path): | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY")) | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |     audio_object_transcript, audio_path_transcript = transcribe_test_helper(w, samples_path=samples_path) | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-15 15:26:30 +01:00
										 |  |  | @pytest.mark.skip("Fails on CI cause it fills up memory") | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  | @pytest.mark.integration | 
					
						
							|  |  |  | @pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed") | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_whisper_local_transcribe(samples_path): | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     w = WhisperTranscriber() | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |     audio_object_transcript, audio_path_transcript = transcribe_test_helper(w, samples_path=samples_path, language="en") | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     assert "segments" not in audio_object_transcript and "segments" not in audio_path_transcript | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-15 15:26:30 +01:00
										 |  |  | @pytest.mark.skip("Fails on CI cause it fills up memory") | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  | @pytest.mark.integration | 
					
						
							|  |  |  | @pytest.mark.skipif(not is_whisper_available(), reason="Whisper is not installed") | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_whisper_local_transcribe_with_params(samples_path): | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     w = WhisperTranscriber() | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |     audio_object, audio_path = transcribe_test_helper(w, samples_path=samples_path, language="en", return_segments=True) | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     assert len(audio_object["segments"]) == 1 and len(audio_path["segments"]) == 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def transcribe_test_helper(whisper, samples_path, **kwargs): | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  |     # this file is 1 second long and contains the word "answer" | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |     file_path = str(samples_path / "audio" / "answer.wav") | 
					
						
							| 
									
										
										
										
											2023-03-13 16:17:07 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # using audio object | 
					
						
							|  |  |  |     with open(file_path, mode="rb") as audio_file: | 
					
						
							|  |  |  |         audio_object_transcript = whisper.transcribe(audio_file=audio_file, **kwargs) | 
					
						
							|  |  |  |         assert "answer" in audio_object_transcript["text"].lower() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # using path to audio file | 
					
						
							|  |  |  |     audio_path_transcript = whisper.transcribe(audio_file=file_path, **kwargs) | 
					
						
							|  |  |  |     assert "answer" in audio_path_transcript["text"].lower() | 
					
						
							|  |  |  |     return audio_object_transcript, audio_path_transcript | 
					
						
							| 
									
										
										
										
											2023-03-28 13:52:21 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OpenAI API key not found") | 
					
						
							|  |  |  | @pytest.mark.integration | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  | def test_whisper_pipeline(samples_path): | 
					
						
							| 
									
										
										
										
											2023-03-28 13:52:21 +02:00
										 |  |  |     w = WhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY")) | 
					
						
							|  |  |  |     pipeline = Pipeline() | 
					
						
							|  |  |  |     pipeline.add_node(component=w, name="whisper", inputs=["File"]) | 
					
						
							| 
									
										
										
										
											2023-04-11 10:33:43 +02:00
										 |  |  |     res = pipeline.run(file_paths=[str(samples_path / "audio" / "answer.wav")]) | 
					
						
							| 
									
										
										
										
											2023-03-28 13:52:21 +02:00
										 |  |  |     assert res["documents"] and len(res["documents"]) == 1 | 
					
						
							|  |  |  |     assert "answer" in res["documents"][0].content.lower() |