fix: Add PromptTemplate __repr__ method (#4058)

Co-authored-by: ZanSara <sarazanzo94@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2023-02-07 08:14:32 +01:00 committed by GitHub
parent a9f13d4641
commit 3273a2714d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 0 deletions

View File

@ -173,6 +173,9 @@ class PromptTemplate(BasePromptTemplate, ABC):
prompt_prepared: str = template.substitute(template_input)
yield prompt_prepared
def __repr__(self):
return f"PromptTemplate(name={self.name}, prompt_text={self.prompt_text}, prompt_params={self.prompt_params})"
class PromptModelInvocationLayer:
"""

View File

@ -51,6 +51,13 @@ def test_prompt_templates():
assert p.prompt_text == "Here is some fake template with variable $baz"
def test_prompt_template_repr():
p = PromptTemplate("t", "Here is variable $baz")
desired_repr = "PromptTemplate(name=t, prompt_text=Here is variable $baz, prompt_params=['baz'])"
assert repr(p) == desired_repr
assert str(p) == desired_repr
def test_create_prompt_model():
model = PromptModel("google/flan-t5-small")
assert model.model_name_or_path == "google/flan-t5-small"