llama-hub/tests/test_github_reader.py
2023-03-01 03:38:39 -05:00

844 lines
27 KiB
Python

from llama_index import Document
import httpx
import pytest
import asyncio
import base64
import os
from unittest.mock import MagicMock, AsyncMock, call
import unittest
from typing import List, Tuple
# Skip this entire module by default: test_github_client below talks to the
# real GitHub API, and allow_module_level=True skips every test in the file
# (including the offline, mock-based ones). Because the skip is raised here,
# the loader_hub imports that follow are not executed while it is in place.
# Remove this to test changes to GithubRepositoryReader.
pytest.skip("Skip by default due to network request.", allow_module_level=True)
from loader_hub.github_repo.utils import (
BufferedAsyncIterator,
BufferedGitBlobDataIterator,
)
from loader_hub.github_repo.github_client import (
GithubClient,
GitBlobResponseModel,
GitTreeResponseModel,
)
from loader_hub.github_repo.base import GithubRepositoryReader
## Test BufferedAsyncIterator ##
## and BufferedGitBlobDataIterator ##
class MockGithubClient:
    """Offline stand-in for ``GithubClient`` used by the iterator tests.

    Only ``get_blob`` is provided. Unlike the real client, it returns a plain
    string that echoes ``owner/repo/sha``, so tests can tell which blob was
    requested without any network access.
    """

    async def get_blob(self, owner, repo, sha):
        return "base64-decoded string blob content {}/{}/{}".format(
            owner, repo, sha
        )
@pytest.mark.asyncio
async def test_buffered_async_iterator():
    """BufferedAsyncIterator refills its buffer ``buffer_size`` items at a time.

    A minimal subclass serves slices of an in-memory list. After each
    ``__anext__`` the buffer holds the unconsumed remainder of the current
    chunk, and ``_index`` counts the items consumed so far.
    """

    class TestIterator(BufferedAsyncIterator):
        def __init__(self, data: List[Tuple[str, str]], buffer_size: int = 2):
            super().__init__(buffer_size)
            self._data = data

        async def _fill_buffer(self):
            # Drop whatever the previous chunk left behind, then load the
            # next window of at most ``_buffer_size`` items (if any remain).
            self._buffer.clear()
            self._buffer = []
            lo = self._index
            hi = min(lo + self._buffer_size, len(self._data))
            if lo < hi:
                self._buffer = self._data[lo:hi]

    data = [(f"my-sha-{n}", f"my/path{n}") for n in range(1, 7)]

    iterator = TestIterator(data, buffer_size=2)
    assert len(iterator._buffer) == 0
    assert iterator._index == 0
    assert iterator._buffer_size == 2

    # With a buffer of 2, odd-numbered reads leave one item buffered and
    # even-numbered reads drain the buffer.
    for consumed, expected_item in enumerate(data, start=1):
        assert await iterator.__anext__() == expected_item
        assert len(iterator._buffer) == consumed % 2
        assert iterator._index == consumed

    with pytest.raises(StopAsyncIteration):
        await iterator.__anext__()
@pytest.mark.asyncio
async def test_buffered_git_blob_data_iterator():
    """BufferedGitBlobDataIterator yields ``(blob content, full path)`` pairs.

    Blobs are fetched from the (mock) client in chunks of ``buffer_size``.
    With 4 blobs and a buffer of 3, the first read fetches three blobs and the
    fourth read triggers a second, one-blob fetch.
    """
    github_client = MockGithubClient()
    owner = "my-owner"
    repo = "my-repo"
    # Inside a coroutine, get_running_loop() is the preferred way to obtain
    # the loop; get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()

    def make_blob(
        sha: str, name: str, size: int
    ) -> GitTreeResponseModel.GitTreeObject:
        """Build a regular-file tree entry for the fake octocat repository."""
        return GitTreeResponseModel.GitTreeObject(
            sha=sha,
            path=name,
            mode="100644",
            type="blob",
            size=size,
            url=f"https://api.github.com/repos/octocat/Hello-World/git/blobs/{sha}",
        )

    blobs_and_paths = [
        (make_blob("my-sha-1", "file1", 123), "path/file1"),
        (make_blob("my-sha-2", "file2", 321), "path/file2"),
        (make_blob("my-sha-3", "file3", 456), "path/to/file3"),
        (make_blob("my-sha-4", "file4", 941), "path/to/file4"),
    ]
    it = BufferedGitBlobDataIterator(
        blobs_and_paths,
        github_client,
        owner,
        repo,
        loop,
        buffer_size=3,
        verbose=False,
    )
    assert len(it._buffer) == 0
    assert it._index == 0
    assert it._buffer_size == 3

    # Buffer length expected after each successive read.
    expected_buffer_lengths = [2, 1, 0, 0]
    for consumed, ((blob, full_path), buffered) in enumerate(
        zip(blobs_and_paths, expected_buffer_lengths), start=1
    ):
        assert await it.__anext__() == (
            f"base64-decoded string blob content {owner}/{repo}/{blob.sha}",
            full_path,
        )
        assert len(it._buffer) == buffered
        assert it._index == consumed

    with pytest.raises(StopAsyncIteration):
        await it.__anext__()
########################
## GithubClient tests ##
@pytest.fixture
def github_client():
    """Provide a verbose ``GithubClient`` for the live-API test.

    The token is read from the ``GITHUB_API_TOKEN`` environment variable; if
    it is unset, ``None`` is passed through as the token.
    """
    token = os.getenv("GITHUB_API_TOKEN")
    return GithubClient(github_token=token, verbose=True)
@pytest.mark.asyncio
async def test_github_client(github_client):
    """Exercise GithubClient against the live GitHub API (network required).

    Walks branch -> commit -> tree -> blob for a pinned commit of
    emptycrown/llama-hub. It checks the returned URLs and SHAs, the exact
    top-level tree listing, and the decoded content of test_requirements.txt.
    """
    owner = "emptycrown"
    repo = "llama-hub"
    branch = "main"
    # Points to "Add spotify reader":
    # https://github.com/emptycrown/llama-hub/commit/0cd691322e5244b48b68e3588d1343eb53f3a112
    commit_sha = "0cd691322e5244b48b68e3588d1343eb53f3a112"

    # test get_branch
    branch_data = await github_client.get_branch(owner, repo, branch)
    assert branch_data.name == branch
    assert (
        branch_data._links.self
        == f"https://api.github.com/repos/{owner}/{repo}/branches/{branch}"
    ), "Branch self link is incorrect"
    assert (
        branch_data._links.html
        == f"https://github.com/{owner}/{repo}/tree/{branch}"
    ), "Branch html link is incorrect"

    # test get_commit
    commit_data = await github_client.get_commit(owner, repo, commit_sha)
    assert commit_data.sha == commit_sha, "Commit sha is incorrect"
    assert (
        commit_data.url
        == f"https://api.github.com/repos/{owner}/{repo}/commits/{commit_sha}"
    ), "Commit url is incorrect"

    # test get_tree
    tree_sha = commit_data.commit.tree.sha
    tree_data = await github_client.get_tree(owner, repo, tree_sha)
    assert (
        tree_data.url
        == f"https://api.github.com/repos/{owner}/{repo}/git/trees/{tree_sha}"
    ), "Tree url is incorrect"
    assert tree_data.sha == tree_sha, "Tree sha is incorrect"

    # The first depth of the tree must contain exactly these (path, type)
    # entries: every expected entry is present and nothing else is.
    expected_files_in_first_depth_of_the_tree: List[Tuple[str, str]] = [
        ("test_requirements.txt", "blob"),
        ("README.md", "blob"),
        ("Makefile", "blob"),
        (".gitignore", "blob"),
        ("tests", "tree"),
        ("loader_hub", "tree"),
        (".github", "tree"),
    ]
    actual_first_depth = [
        (tree_obj.path, tree_obj.type) for tree_obj in tree_data.tree
    ]
    assert len(actual_first_depth) == len(
        expected_files_in_first_depth_of_the_tree
    ), "The number of files in the first depth of the tree is incorrect"
    for file in expected_files_in_first_depth_of_the_tree:
        assert (
            file in actual_first_depth
        ), f"{file} is not in the first depth of the tree"
    for path, obj_type in actual_first_depth:
        assert (
            path,
            obj_type,
        ) in expected_files_in_first_depth_of_the_tree, (
            f"{path} is not in the expected files"
        )

    # find test_requirements.txt in the tree
    test_requirements_txt = next(
        (
            tree_obj
            for tree_obj in tree_data.tree
            if tree_obj.path == "test_requirements.txt"
        ),
        None,
    )
    assert (
        test_requirements_txt is not None
    ), "test_requirements.txt is not in the tree"

    # test get_blob
    blob_data = await github_client.get_blob(
        owner, repo, test_requirements_txt.sha
    )
    assert blob_data.encoding == "base64", "Blob encoding is incorrect"
    assert (
        blob_data.url
        == f"https://api.github.com/repos/{owner}/{repo}/git/blobs/{test_requirements_txt.sha}"
    ), "Blob url is incorrect"
    assert blob_data.sha == test_requirements_txt.sha, "Blob sha is incorrect"

    # decode the base64 blob content to utf-8 text
    decoded_blob_content = base64.b64decode(blob_data.content).decode("utf-8")
    expected_lines = [
        "# For testing",
        "pytest==7.2.1",
        "pytest-dotenv==0.5.2",
        "# TODO: remove gpt_index after migration",
        "https://github.com/jerryjliu/gpt_index/archive/master.zip",
        "llama-index",
        "# For linting",
        "# linting stubs",
        "types-requests==2.28.11.8",
        "# formatting",
        "black==22.12.0",
        "isort==5.11.4",
    ]
    actual_lines = [
        line for line in decoded_blob_content.splitlines() if line != ""
    ]
    # Compare the complete line lists: zip() would silently ignore missing or
    # extra lines on either side.
    assert actual_lines == expected_lines, (
        f"Decoded blob content is incorrect:\n{actual_lines}\n"
        f"!=\n{expected_lines}"
    )
class TestGithubRepositoryReader(unittest.TestCase):
    """Unit tests for GithubRepositoryReader's directory/extension filters.

    The reader is built around a MagicMock client, so no network is used.
    """

    def setUp(self):
        self.github_client = MagicMock()
        self.owner = "owner"
        self.repo = "repo"
        self.reader = GithubRepositoryReader(
            self.github_client, self.owner, self.repo
        )

    def _allowed_paths(self, paths):
        """Return, in order, the subset of *paths* the reader's filters allow."""
        return [path for path in paths if self.reader._allow_tree_obj(path)]

    def test__check_filter_directories(self):
        tree_obj_path = "path/to/some/file.py"
        self.reader._filter_directories = (
            ["path/to"],
            GithubRepositoryReader.FilterType.INCLUDE,
        )
        self.assertTrue(self.reader._check_filter_directories(tree_obj_path))
        self.reader._filter_directories = (
            ["path/to"],
            GithubRepositoryReader.FilterType.EXCLUDE,
        )
        self.assertFalse(self.reader._check_filter_directories(tree_obj_path))

    def test__check_filter_file_extensions(self):
        tree_obj_path = "path/to/some/file.py"
        self.reader._filter_file_extensions = (
            [".py"],
            GithubRepositoryReader.FilterType.INCLUDE,
        )
        self.assertTrue(
            self.reader._check_filter_file_extensions(tree_obj_path)
        )
        # Excluding an unrelated extension must still allow the .py file.
        self.reader._filter_file_extensions = (
            [".txt"],
            GithubRepositoryReader.FilterType.EXCLUDE,
        )
        self.assertTrue(
            self.reader._check_filter_file_extensions(tree_obj_path)
        )

    def test__allow_tree_obj(self):
        tree_obj_paths = [
            "src/file.py",
            "src/file.txt",
            "src/dir1/file.js",
            "src/assets/file.png",
            "src/assets/file.jpg",
            "src/assets/file.jpeg",
            "src/assets/file.gif",
            "src/assets/file.svg",
            "src/assets/file.ico",
            "src/documents/file.pdf",
            "src/documents/file.doc",
            "src/documents/file.docx",
            "src/documents/file.xls",
            "src/documents/file.xlsx",
            "src/documents/file.ppt",
            "src/documents/file.pptx",
            "src/documents/file.odt",
            "src/documents/file.ods",
            "src/dir2/subdir/file.cpp",
            "src/dir2/subdir/file.c",
            "src/dir2/subdir/file.h",
            "src/dir2/subdir/file.hpp",
            "src/dir2/subdir/file.java",
            "src/dir2/foo.cc",
            "src/dir2/foo.svg",
            "src/dir2/subdir/file.go",
            "src/sub/folder/loading.svg",
            "src/sub/folder/loading.ico",
        ]

        # Scenario 1: exclude two directories and several extensions.
        self.reader._filter_directories = (
            ["src/assets", "src/documents"],
            GithubRepositoryReader.FilterType.EXCLUDE,
        )
        self.reader._filter_file_extensions = (
            [".svg", ".ico", ".cpp", ".c", ".h"],
            GithubRepositoryReader.FilterType.EXCLUDE,
        )
        expected_tree_obj_paths = [
            "src/file.py",
            "src/file.txt",
            "src/dir1/file.js",
            "src/dir2/subdir/file.hpp",
            "src/dir2/subdir/file.java",
            "src/dir2/foo.cc",
            "src/dir2/subdir/file.go",
        ]
        # Pass the message via msg=: a trailing `, "..."` after the call only
        # builds a discarded tuple and the message is never shown.
        self.assertCountEqual(
            expected_tree_obj_paths,
            self._allowed_paths(tree_obj_paths),
            msg="Tree object paths are incorrect",
        )

        # Scenario 2: include two directories, exclude some extensions.
        self.reader._filter_directories = (
            ["src/dir2/subdir", "src/documents"],
            GithubRepositoryReader.FilterType.INCLUDE,
        )
        self.reader._filter_file_extensions = (
            [".png", ".svg", ".ico", ".jpg", ".java"],
            GithubRepositoryReader.FilterType.EXCLUDE,
        )
        expected_tree_obj_paths = [
            "src/documents/file.pdf",
            "src/documents/file.doc",
            "src/documents/file.docx",
            "src/documents/file.xls",
            "src/documents/file.xlsx",
            "src/documents/file.ppt",
            "src/documents/file.pptx",
            "src/documents/file.odt",
            "src/documents/file.ods",
            "src/dir2/subdir/file.cpp",
            "src/dir2/subdir/file.c",
            "src/dir2/subdir/file.h",
            "src/dir2/subdir/file.hpp",
            "src/dir2/subdir/file.go",
        ]
        self.assertCountEqual(
            expected_tree_obj_paths,
            self._allowed_paths(tree_obj_paths),
            msg="Tree object paths are incorrect",
        )
## Test GithubRepositoryReader's _recurse_tree method
@pytest.mark.asyncio
async def test__recurse_tree():
    """_recurse_tree walks nested trees and returns every blob with its full path.

    The mocked repository looks like::

        1234/            file1.py, folder1/, file2.py, file3.py
        91011 (folder1)  file4.py, folder3/
        1819  (folder3)  file5.py

    get_tree must be called exactly once per tree. An unknown sha makes the
    mock raise httpx.HTTPError, which must propagate out of _recurse_tree.
    """
    github_client = MagicMock()
    owner = "owner"
    repo = "repo"
    reader = GithubRepositoryReader(github_client, owner, repo)

    def blob(path: str, sha: str, size: int):
        """Build a regular-file tree entry."""
        return GitTreeResponseModel.GitTreeObject(
            type="blob",
            path=path,
            sha=sha,
            mode="100644",
            size=size,
            url=f"https://api.github.com/repos/owner/repo/git/blobs/{sha}",
        )

    def subtree(path: str, sha: str):
        """Build a directory tree entry (URL kept as in the original fixture)."""
        return GitTreeResponseModel.GitTreeObject(
            type="tree",
            path=path,
            sha=sha,
            mode="040000",
            size=None,
            url=f"https://api.github.com/repos/owner/repo/git/blobs/{sha}",
        )

    # sha of the root tree (the first get_tree call)
    tree_sha = "1234"
    # tree sha -> entries returned by get_tree for that tree
    trees = {
        tree_sha: [
            blob("file1.py", "5678", 1111),
            subtree("folder1", "91011"),
            blob("file2.py", "1213", 3333),
            blob("file3.py", "1415", 4444),
        ],
        "91011": [
            blob("file4.py", "1617", 6666),
            subtree("folder3", "1819"),
        ],
        "1819": [
            blob("file5.py", "2021", 8888),
        ],
    }

    def get_tree_side_effect(owner, repo, sha):
        if sha not in trees:
            raise httpx.HTTPError(
                f"404 Client Error: Not Found for url: https://api.github.com/repos/{owner}/{repo}/git/trees/{sha}"
            )
        return GitTreeResponseModel(
            sha=sha,
            tree=trees[sha],
            truncated=False,
            url=f"https://api.github.com/repos/owner/repo/git/trees/{sha}",
        )

    github_client.get_tree = AsyncMock(side_effect=get_tree_side_effect)

    expected_blobs_and_full_paths = [
        (blob("file1.py", "5678", 1111), "file1.py"),
        (blob("file2.py", "1213", 3333), "file2.py"),
        (blob("file4.py", "1617", 6666), "folder1/file4.py"),
        (blob("file5.py", "2021", 8888), "folder1/folder3/file5.py"),
        (blob("file3.py", "1415", 4444), "file3.py"),
    ]

    blobs_and_full_paths = await reader._recurse_tree(tree_sha)

    # make sure get_tree was called the expected number of times
    assert (
        github_client.get_tree.call_count == 3
    ), "There should be only 3 calls to get_tree (one for the root tree, and one for each subfolder folder1 and folder3)"
    # zip() below stops at the shorter list, so check the counts first;
    # otherwise missing or extra blobs would go unnoticed.
    assert len(blobs_and_full_paths) == len(
        expected_blobs_and_full_paths
    ), "_recurse_tree returned the wrong number of blobs"
    # sort the expected and actual results by full path so we can compare them
    for (blob_obj, full_path), (expected_blob, expected_full_path) in zip(
        sorted(blobs_and_full_paths, key=lambda x: x[1]),
        sorted(expected_blobs_and_full_paths, key=lambda x: x[1]),
    ):
        assert (
            blob_obj == expected_blob
        ), "actual blob info does not match expected blob info"
        assert (
            full_path == expected_full_path
        ), "actual full path does not match expected full path"

    with pytest.raises(
        httpx.HTTPError,
        match="404 Client Error: Not Found for url: https://api.github.com/repos/owner/repo/git/trees/12345",
    ):
        await reader._recurse_tree("12345")
@pytest.mark.asyncio
async def test__generate_documents():
    """_generate_documents fetches every blob and builds one Document per
    UTF-8-decodable blob.

    get_blob is mocked from a sha -> payload table. The PNG blob cannot be
    decoded as UTF-8, so 5 blobs are fetched but only 4 documents are
    produced. An unknown sha makes the mock raise httpx.HTTPError, which must
    propagate out of _generate_documents.
    """
    github_client = MagicMock()
    owner = "owner"
    repo = "repo"
    reader = GithubRepositoryReader(
        github_client=github_client,
        owner=owner,
        repo=repo,
        use_parser=False,
        verbose=False,
    )

    def blob(path: str, sha: str, size: int):
        """Build a regular-file tree entry for the mocked repository."""
        return GitTreeResponseModel.GitTreeObject(
            type="blob",
            path=path,
            sha=sha,
            mode="100644",
            size=size,
            url=f"https://api.github.com/repos/owner/repo/git/blobs/{sha}",
        )

    blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = [
        (blob("file1.py", "5678", 1111), "file1.py"),
        (blob("file2.ts", "1213", 3333), "folder1/file2.ts"),
        (blob("file3.rs", "1415", 4444), "folder1/folder2/file3.rs"),
        (blob("file4.cc", "1617", 6666), "folder1/folder2/folder3/file4.cc"),
        # This file should not end up in the generated documents since it
        # should fail to decode as utf-8.
        (blob("file5.png", "2021", 8888), "folder1/folder2/folder3/file5.png"),
    ]

    # sha -> (base64 content, size, node_id) returned by the mocked get_blob.
    blob_payloads = {
        "5678": ("cHJpbnQoJ2hlbGxvIHdvcmxkJyk=", 1111, "1234"),
        "1213": ("Y29uc29sZS5sb2coJ2hlbGxvIHdvcmxkJyk=", 3333, "2345"),
        "1415": ("cHJpbnRsbiEoImhlbGxvIHdvcmxkIik=", 4444, "3456"),
        "1617": (
            "c3RkOjpjb3V0IDw8ICJoZWxsbyB3b3JsZCIgPDwgc3RkOjplbmRsOw==",
            6666,
            "4567",
        ),
        # PNG bytes (leading 0x89) are not valid UTF-8.
        "2021": (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
            8888,
            "5678",
        ),
    }

    async def get_blob_side_effect(owner: str, repo: str, sha: str):
        if sha not in blob_payloads:
            raise httpx.HTTPError(
                f"404 Client Error: Not Found for url: https://api.github.com/repos/{owner}/{repo}/git/blobs/{sha}"
            )
        content, size, node_id = blob_payloads[sha]
        return GitBlobResponseModel(
            content=content,
            encoding="base64",
            sha=sha,
            size=size,
            url=f"https://api.github.com/repos/owner/repo/git/blobs/{sha}",
            node_id=node_id,
        )

    github_client.get_blob = AsyncMock(side_effect=get_blob_side_effect)
    documents = await reader._generate_documents(blobs_and_paths)

    assert (
        github_client.get_blob.await_count == 5
    ), "get_blob should be awaited once for each of the 5 blobs"
    # get_blob should be awaited with the correct arguments, in input order.
    # (assert_has_awaits raises its own AssertionError; appending `, "msg"`
    # to the call would only build a discarded tuple.)
    github_client.get_blob.assert_has_awaits(
        [call(owner, repo, blob_obj.sha) for blob_obj, _ in blobs_and_paths]
    )
    assert (
        len(documents) == 4
    ), "There should be 4 documents generated from the blobs_and_paths"

    expected_documents = [
        Document(
            text="print('hello world')",
            extra_info={
                "file_path": "file1.py",
                "file_name": "file1.py",
            },
        ),
        Document(
            text="console.log('hello world')",
            extra_info={
                "file_path": "folder1/file2.ts",
                "file_name": "file2.ts",
            },
        ),
        Document(
            text='println!("hello world")',
            extra_info={
                "file_path": "folder1/folder2/file3.rs",
                "file_name": "file3.rs",
            },
        ),
        Document(
            text='std::cout << "hello world" << std::endl;',
            extra_info={
                "file_path": "folder1/folder2/folder3/file4.cc",
                "file_name": "file4.cc",
            },
        ),
    ]
    for document, expected_document in zip(
        sorted(documents, key=lambda x: x.extra_info["file_path"]),
        sorted(expected_documents, key=lambda x: x.extra_info["file_path"]),
    ):
        assert (
            document.text == expected_document.text
        ), "The text of the document should be the decoded content of the blob"
        assert (
            document.extra_info == expected_document.extra_info
        ), "The extra_info of the document should be the file_path and file_name"

    with pytest.raises(
        httpx.HTTPError,
        match="404 Client Error: Not Found for url: https://api.github.com/repos/owner/repo/git/blobs/12345",
    ):
        await reader._generate_documents(
            [(blob("file1.py", "12345", 1111), "file1.py")]
        )