mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-06-26 22:00:19 +00:00
feat(TokenTracker): Add context manager support to simplify token tracking
This commit is contained in:
parent
6eea8bdf5d
commit
164faf94e2
@ -115,38 +115,36 @@ def main():
|
|||||||
# Initialize RAG instance
|
# Initialize RAG instance
|
||||||
rag = asyncio.run(initialize_rag())
|
rag = asyncio.run(initialize_rag())
|
||||||
|
|
||||||
# Reset tracker before processing queries
|
|
||||||
token_tracker.reset()
|
|
||||||
|
|
||||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||||
rag.insert(f.read())
|
rag.insert(f.read())
|
||||||
|
|
||||||
print(
|
# Context Manager Method
|
||||||
rag.query(
|
with token_tracker:
|
||||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
print(
|
||||||
|
rag.query(
|
||||||
|
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
rag.query(
|
rag.query(
|
||||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
rag.query(
|
rag.query(
|
||||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
"What are the top themes in this story?",
|
||||||
|
param=QueryParam(mode="global"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
rag.query(
|
rag.query(
|
||||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
"What are the top themes in this story?",
|
||||||
|
param=QueryParam(mode="hybrid"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Display final token usage after main query
|
|
||||||
print("Token usage:", token_tracker.get_usage())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -44,14 +44,10 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|||||||
|
|
||||||
# function test
|
# function test
|
||||||
async def test_funcs():
|
async def test_funcs():
|
||||||
# Reset tracker before processing queries
|
# Context Manager Method
|
||||||
token_tracker.reset()
|
with token_tracker:
|
||||||
|
result = await llm_model_func("How are you?")
|
||||||
result = await llm_model_func("How are you?")
|
print("llm_model_func: ", result)
|
||||||
print("llm_model_func: ", result)
|
|
||||||
|
|
||||||
# Display final token usage after main query
|
|
||||||
print("Token usage:", token_tracker.get_usage())
|
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(test_funcs())
|
asyncio.run(test_funcs())
|
||||||
|
@ -962,6 +962,13 @@ class TokenTracker:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.reset()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
print(self)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.prompt_tokens = 0
|
self.prompt_tokens = 0
|
||||||
self.completion_tokens = 0
|
self.completion_tokens = 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user