feat(TokenTracker): Add context manager support to simplify token tracking

This commit is contained in:
choizhang 2025-03-30 00:59:23 +08:00
parent 6eea8bdf5d
commit 164faf94e2
3 changed files with 31 additions and 30 deletions

View File

@@ -115,38 +115,36 @@ def main():
# Initialize RAG instance # Initialize RAG instance
rag = asyncio.run(initialize_rag()) rag = asyncio.run(initialize_rag())
# Reset tracker before processing queries
token_tracker.reset()
with open("./book.txt", "r", encoding="utf-8") as f: with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read()) rag.insert(f.read())
print( # Context Manager Method
rag.query( with token_tracker:
"What are the top themes in this story?", param=QueryParam(mode="naive") print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
) )
)
print( print(
rag.query( rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local") "What are the top themes in this story?", param=QueryParam(mode="local")
)
) )
)
print( print(
rag.query( rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global") "What are the top themes in this story?",
param=QueryParam(mode="global"),
)
) )
)
print( print(
rag.query( rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid") "What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
) )
)
# Display final token usage after main query
print("Token usage:", token_tracker.get_usage())
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -44,14 +44,10 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
# function test # function test
async def test_funcs(): async def test_funcs():
# Reset tracker before processing queries # Context Manager Method
token_tracker.reset() with token_tracker:
result = await llm_model_func("How are you?")
result = await llm_model_func("How are you?") print("llm_model_func: ", result)
print("llm_model_func: ", result)
# Display final token usage after main query
print("Token usage:", token_tracker.get_usage())
asyncio.run(test_funcs()) asyncio.run(test_funcs())

View File

@@ -962,6 +962,13 @@ class TokenTracker:
def __init__(self): def __init__(self):
self.reset() self.reset()
def __enter__(self):
self.reset()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
print(self)
def reset(self): def reset(self):
self.prompt_tokens = 0 self.prompt_tokens = 0
self.completion_tokens = 0 self.completion_tokens = 0