Merge pull request #100 from emptycrown/jerry/revert_gh_changes

Revert "Merge pull request #73 from ahmetkca/github-reader-test-and-fix"
Jerry Liu, 2023-03-11 19:14:27 -08:00 (committed by GitHub)
commit 75480e16d0
8 changed files with 77 additions and 900 deletions

.gitignore vendored
View File

@ -1,4 +1,4 @@
*.egg-info/
.modules
**/__pycache__/
**/__pycache__/

View File

@ -1 +1 @@
"""Init file."""
"""Init file."""

View File

@ -14,7 +14,7 @@ import os
from llama_index import download_loader
download_loader("GithubRepositoryReader")
from llama_index.readers.llamahub_modules.github_repo import GithubRepositoryReader, GithubClient
from modules.github_repo import GithubRepositoryReader, GithubClient
github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
loader = GithubRepositoryReader(
@ -51,7 +51,7 @@ assert (
from llama_index import download_loader
download_loader("GithubRepositoryReader")
from llama_index.readers.llamahub_modules.github_repo import GithubClient, GithubRepositoryReader
from modules.github_repo import GithubClient, GithubRepositoryReader
docs = None
@ -79,5 +79,5 @@ if docs is None:
index = GPTSimpleVectorIndex(docs)
index.query("Explain each LlamaIndex class?")
index.query("Explain each index class?")
```
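
For orientation, the README hunks above reduce to the following end-to-end sketch, assuming a valid GITHUB_TOKEN in the environment and placeholder owner/repo values; the import path of the downloaded module differs between the two sides of this diff, so the one used here is an assumption.

```python
import os

from llama_index import GPTSimpleVectorIndex, download_loader

# Fetch the loader module; the import path below is one of the two shown in
# the hunks above and may differ depending on the installed llama-index version.
download_loader("GithubRepositoryReader")
from llama_index.readers.llamahub_modules.github_repo import (
    GithubClient,
    GithubRepositoryReader,
)

# A GitHub personal access token is read from the environment.
github_client = GithubClient(os.getenv("GITHUB_TOKEN"))

loader = GithubRepositoryReader(
    github_client,
    owner="jerryjliu",   # placeholder owner
    repo="llama_index",  # placeholder repo
    verbose=True,
)

# Load from a branch (a commit sha works too), then index and query.
docs = loader.load_data(branch="main")
index = GPTSimpleVectorIndex(docs)
print(index.query("Explain each index class?"))
```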

View File

@ -5,49 +5,26 @@ Retrieves the contents of a Github repository and returns a list of documents.
The documents are either the contents of the files in the repository or
the text extracted from the files using the parser.
"""
import os
import asyncio
import base64
import binascii
import enum
import logging
import os
import pathlib
import tempfile
import enum
import sys
from typing import Any, Callable, List, Optional, Tuple
from llama_index.readers.base import BaseReader
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.llamahub_modules.github_repo.github_client import (
BaseGithubClient, GitBranchResponseModel, GitCommitResponseModel,
GithubClient, GitTreeResponseModel)
from llama_index.readers.llamahub_modules.github_repo.utils import (
BufferedGitBlobDataIterator, get_file_extension, print_if_verbose)
from llama_index.readers.schema.base import Document
if "pytest" in sys.modules:
from loader_hub.github_repo.github_client import (
BaseGithubClient,
GitBranchResponseModel,
GitCommitResponseModel,
GithubClient,
GitTreeResponseModel,
)
from loader_hub.github_repo.utils import (
BufferedGitBlobDataIterator,
print_if_verbose,
get_file_extension,
)
else:
from llama_index.readers.llamahub_modules.github_repo.github_client import (
BaseGithubClient,
GithubClient,
GitBranchResponseModel,
GitCommitResponseModel,
GitTreeResponseModel,
)
from llama_index.readers.llamahub_modules.github_repo.utils import (
BufferedGitBlobDataIterator,
print_if_verbose,
get_file_extension,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -85,7 +62,7 @@ class GithubRepositoryReader(BaseReader):
github_client: BaseGithubClient,
owner: str,
repo: str,
use_parser: bool = False,
use_parser: bool = True,
verbose: bool = False,
concurrent_requests: int = 5,
filter_directories: Optional[Tuple[List[str], FilterType]] = None,
@ -146,8 +123,6 @@ class GithubRepositoryReader(BaseReader):
:return: True if the tree object should be allowed, False otherwise
"""
if self._filter_directories is None:
return True
filter_directories, filter_type = self._filter_directories
print_if_verbose(
self._verbose,
@ -161,16 +136,17 @@ class GithubRepositoryReader(BaseReader):
or directory.startswith(tree_obj_path)
for directory in filter_directories
)
if filter_type == self.FilterType.INCLUDE:
elif filter_type == self.FilterType.INCLUDE:
return any(
tree_obj_path.startswith(directory)
or directory.startswith(tree_obj_path)
for directory in filter_directories
)
raise ValueError(
f"Unknown filter type: {filter_type}. "
"Please use either 'INCLUDE' or 'EXCLUDE'."
)
else:
raise ValueError(
f"Unknown filter type: {filter_type}. "
"Please use either 'ignore' or 'include'."
)
def _check_filter_file_extensions(self, tree_obj_path: str) -> bool:
"""
@ -180,8 +156,6 @@ class GithubRepositoryReader(BaseReader):
:return: True if the tree object should be allowed, False otherwise
"""
if self._filter_file_extensions is None:
return True
filter_file_extensions, filter_type = self._filter_file_extensions
print_if_verbose(
self._verbose,
@ -190,15 +164,14 @@ class GithubRepositoryReader(BaseReader):
)
if filter_type == self.FilterType.EXCLUDE:
return (
get_file_extension(tree_obj_path) not in filter_file_extensions
)
if filter_type == self.FilterType.INCLUDE:
return get_file_extension(tree_obj_path) not in filter_file_extensions
elif filter_type == self.FilterType.INCLUDE:
return get_file_extension(tree_obj_path) in filter_file_extensions
raise ValueError(
f"Unknown filter type: {filter_type}. "
"Please use either 'INCLUDE' or 'EXCLUDE'."
)
else:
raise ValueError(
f"Unknown filter type: {filter_type}. "
"Please use either 'ignore' or 'include'."
)
def _allow_tree_obj(self, tree_obj_path: str) -> bool:
"""
@ -209,17 +182,13 @@ class GithubRepositoryReader(BaseReader):
:return: True if the tree object should be allowed, False otherwise
"""
is_dir_allowed = True
if self._filter_directories is not None:
is_dir_allowed = self._check_filter_directories(tree_obj_path)
return self._check_filter_directories(tree_obj_path)
is_file_ext_allowed = True
if self._filter_file_extensions is not None:
is_file_ext_allowed = self._check_filter_file_extensions(
tree_obj_path
)
return self._check_filter_file_extensions(tree_obj_path)
return is_dir_allowed and is_file_ext_allowed
return True
def _load_data_from_commit(self, commit_sha: str) -> List[Document]:
"""
@ -231,18 +200,12 @@ class GithubRepositoryReader(BaseReader):
:return: list of documents
"""
commit_response: GitCommitResponseModel = (
self._loop.run_until_complete(
self._github_client.get_commit(
self._owner, self._repo, commit_sha
)
)
commit_response: GitCommitResponseModel = self._loop.run_until_complete(
self._github_client.get_commit(self._owner, self._repo, commit_sha)
)
tree_sha = commit_response.commit.tree.sha
blobs_and_paths = self._loop.run_until_complete(
self._recurse_tree(tree_sha)
)
blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))
print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
@ -265,9 +228,7 @@ class GithubRepositoryReader(BaseReader):
)
tree_sha = branch_data.commit.commit.tree.sha
blobs_and_paths = self._loop.run_until_complete(
self._recurse_tree(tree_sha)
)
blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))
print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
@ -305,11 +266,7 @@ class GithubRepositoryReader(BaseReader):
raise ValueError("You must specify one of commit or branch.")
async def _recurse_tree(
self,
tree_sha: str,
current_path: str = "",
current_depth: int = 0,
max_depth: int = -1,
self, tree_sha: str, current_path: str = "", current_depth: int = 0
) -> Any:
"""
Recursively get all blob tree objects in a tree.
@ -324,16 +281,9 @@ class GithubRepositoryReader(BaseReader):
:return: list of tuples of
(tree object, file's full path relative to the root of the repo)
"""
if max_depth != -1 and current_depth > max_depth:
return []
blobs_and_full_paths: List[
Tuple[GitTreeResponseModel.GitTreeObject, str]
] = []
blobs_and_full_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = []
print_if_verbose(
self._verbose,
"\t" * current_depth + f"current path: {current_path}",
self._verbose, "\t" * current_depth + f"current path: {current_path}"
)
tree_data: GitTreeResponseModel = await self._github_client.get_tree(
@ -345,37 +295,39 @@ class GithubRepositoryReader(BaseReader):
for tree_obj in tree_data.tree:
file_path = os.path.join(current_path, tree_obj.path)
if not self._allow_tree_obj(file_path):
print_if_verbose(
self._verbose,
"\t" * current_depth
+ f"ignoring {tree_obj.path} due to filter",
)
continue
if tree_obj.type == "tree":
print_if_verbose(
self._verbose,
"\t" * current_depth + f"recursing into {tree_obj.path}",
)
if not self._check_filter_directories(file_path):
print_if_verbose(
self._verbose,
"\t" * current_depth
+ f"ignoring directory {tree_obj.path} due to filter",
)
continue
blobs_and_full_paths.extend(
await self._recurse_tree(
tree_obj.sha, file_path, current_depth + 1, max_depth
)
await self._recurse_tree(tree_obj.sha, file_path, current_depth + 1)
)
elif tree_obj.type == "blob":
print_if_verbose(
self._verbose,
"\t" * current_depth + f"found blob {tree_obj.path}",
self._verbose, "\t" * current_depth + f"found blob {tree_obj.path}"
)
if not self._check_filter_file_extensions(file_path):
print_if_verbose(
self._verbose,
"\t" * current_depth
+ f"ignoring file {tree_obj.path} due to filter",
)
continue
blobs_and_full_paths.append((tree_obj, file_path))
return blobs_and_full_paths
async def _generate_documents(
self,
blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]],
self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]]
) -> List[Document]:
"""
Generate documents from a list of blobs and their full paths.
@ -396,9 +348,7 @@ class GithubRepositoryReader(BaseReader):
documents = []
async for blob_data, full_path in buffered_iterator:
print_if_verbose(
self._verbose, f"generating document for {full_path}"
)
print_if_verbose(self._verbose, f"generating document for {full_path}")
assert (
blob_data.encoding == "base64"
), f"blob encoding {blob_data.encoding} not supported"
@ -454,11 +404,7 @@ class GithubRepositoryReader(BaseReader):
return documents
def _parse_supported_file(
self,
file_path: str,
file_content: bytes,
tree_sha: str,
tree_path: str,
self, file_path: str, file_content: bytes, tree_sha: str, tree_path: str
) -> Optional[Document]:
"""
Parse a file if it is supported by a parser.
@ -492,9 +438,7 @@ class GithubRepositoryReader(BaseReader):
tmpfile.flush()
tmpfile.close()
try:
parsed_file = parser.parse_file(
pathlib.Path(tmpfile.name)
)
parsed_file = parser.parse_file(pathlib.Path(tmpfile.name))
parsed_file = "\n\n".join(parsed_file)
except Exception as e:
print_if_verbose(
@ -536,9 +480,7 @@ if __name__ == "__main__":
return wrapper
github_client = GithubClient(
github_token=os.environ["GITHUB_TOKEN"], verbose=True
)
github_client = GithubClient(github_token=os.environ["GITHUB_TOKEN"], verbose=True)
reader1 = GithubRepositoryReader(
github_client=github_client,
@ -551,16 +493,7 @@ if __name__ == "__main__":
GithubRepositoryReader.FilterType.INCLUDE,
),
filter_file_extensions=(
[
".png",
".jpg",
".jpeg",
".gif",
".svg",
".ico",
"json",
".ipynb",
],
[".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", "json", ".ipynb"],
GithubRepositoryReader.FilterType.EXCLUDE,
),
)
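
As a standalone illustration of the filter semantics shown in the hunks above (directory filters match by path prefix in either direction, extension filters by file extension), here is a minimal sketch; `FilterType`, `check_directories`, and `check_extensions` are illustrative names rather than the reader's actual API.

```python
import os
from enum import Enum, auto
from typing import List, Tuple


class FilterType(Enum):
    """Mirrors GithubRepositoryReader.FilterType in the diff above."""
    INCLUDE = auto()
    EXCLUDE = auto()


def check_directories(path: str, filter_: Tuple[List[str], FilterType]) -> bool:
    """Allow or reject a tree path, as in _check_filter_directories."""
    directories, filter_type = filter_
    matches = any(
        path.startswith(d) or d.startswith(path) for d in directories
    )
    return matches if filter_type == FilterType.INCLUDE else not matches


def check_extensions(path: str, filter_: Tuple[List[str], FilterType]) -> bool:
    """Allow or reject a blob path, as in _check_filter_file_extensions."""
    extensions, filter_type = filter_
    ext = os.path.splitext(path)[1]  # stand-in for get_file_extension
    matches = ext in extensions
    return matches if filter_type == FilterType.INCLUDE else not matches


# Keep only docs/, and drop image blobs.
print(check_directories("docs/guides/usage.md", (["docs"], FilterType.INCLUDE)))        # True
print(check_directories("src/readers/base.py", (["docs"], FilterType.INCLUDE)))         # False
print(check_extensions("src/assets/logo.png", ([".png", ".svg"], FilterType.EXCLUDE)))  # False
```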

View File

@ -104,8 +104,7 @@ class GitCommitResponseModel(DataClassJsonMixin):
tree: Tree
commit: Commit
url: str
sha: str
@dataclass
class GitBranchResponseModel(DataClassJsonMixin):
@ -138,14 +137,7 @@ class GitBranchResponseModel(DataClassJsonMixin):
commit: Commit
@dataclass
class Links(DataClassJsonMixin):
self: str
html: str
commit: Commit
name: str
_links: Links
from typing import Protocol
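
The hunks above build the GitHub API response models on `DataClassJsonMixin` from the dataclasses_json package. Below is a minimal sketch of that pattern, using a simplified stand-in model rather than the library's actual class layout.

```python
from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin


@dataclass
class Tree(DataClassJsonMixin):
    sha: str


@dataclass
class Commit(DataClassJsonMixin):
    tree: Tree


@dataclass
class CommitResponse(DataClassJsonMixin):
    """Simplified stand-in for GitCommitResponseModel."""
    commit: Commit
    url: str
    sha: str


payload = {
    "commit": {"tree": {"sha": "abc123"}},
    "url": "https://api.github.com/repos/octocat/Hello-World/commits/def456",
    "sha": "def456",
}

response = CommitResponse.from_dict(payload)
print(response.commit.tree.sha)  # abc123
```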

View File

@ -5,23 +5,12 @@ This module contains utility functions for the Github readers.
"""
import asyncio
import os
import sys
import time
from abc import ABC, abstractmethod
from typing import List, Tuple
if "pytest" in sys.modules:
from loader_hub.github_repo.github_client import (
GitBlobResponseModel,
GithubClient,
GitTreeResponseModel,
)
else:
from llama_index.readers.llamahub_modules.github_repo.github_client import (
GitBlobResponseModel,
GithubClient,
GitTreeResponseModel,
)
from gpt_index.readers.github_readers.github_api_client import (
GitBlobResponseModel, GithubClient, GitTreeResponseModel)
def print_if_verbose(verbose: bool, message: str) -> None:
@ -164,8 +153,7 @@ class BufferedGitBlobDataIterator(BufferedAsyncIterator):
if self._verbose:
end_t = time.time()
blob_names_and_sizes = [
(blob.path, blob.size)
for blob, _ in self._blobs_and_paths[start:end]
(blob.path, blob.size) for blob, _ in self._blobs_and_paths[start:end]
]
print(
"Time to get blobs ("
@ -175,7 +163,5 @@ class BufferedGitBlobDataIterator(BufferedAsyncIterator):
self._buffer = [
(result, path)
for result, (_, path) in zip(
results, self._blobs_and_paths[start:end]
)
for result, (_, path) in zip(results, self._blobs_and_paths[start:end])
]
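
The iterator touched in the hunks above refills an internal buffer in fixed-size batches and then yields items one at a time. Below is a minimal sketch of that buffering pattern, mirroring the `TestIterator` subclass in the test file further down; the attribute and method names come from the diff, while the implementation details here are an assumption.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, List


class BufferedAsyncIterator(ABC):
    """Minimal sketch of the batched-buffer iteration pattern."""

    def __init__(self, buffer_size: int = 2) -> None:
        self._buffer_size = buffer_size
        self._buffer: List[Any] = []
        self._index = 0

    @abstractmethod
    async def _fill_buffer(self) -> None:
        ...

    def __aiter__(self) -> "BufferedAsyncIterator":
        return self

    async def __anext__(self) -> Any:
        if not self._buffer:
            await self._fill_buffer()
        if not self._buffer:
            raise StopAsyncIteration
        self._index += 1
        return self._buffer.pop(0)


class ListIterator(BufferedAsyncIterator):
    """Serves items from an in-memory list, buffer_size at a time."""

    def __init__(self, data: List[Any], buffer_size: int = 2) -> None:
        super().__init__(buffer_size)
        self._data = data

    async def _fill_buffer(self) -> None:
        start = self._index
        end = min(start + self._buffer_size, len(self._data))
        self._buffer = self._data[start:end]


async def main() -> None:
    async for item in ListIterator(["a", "b", "c"], buffer_size=2):
        print(item)


asyncio.run(main())
```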

View File

@ -2,10 +2,8 @@
# For testing
pytest==7.2.1
pytest-dotenv==0.5.2
pytest-asyncio
# TODO: remove gpt_index after migration
https://github.com/jerryjliu/gpt_index/archive/master.zip
httpx
llama-index

View File

@ -1,190 +1,19 @@
from llama_index import Document
import httpx
import pytest
import asyncio
import base64
import os
from unittest.mock import MagicMock, AsyncMock, call
import unittest
from typing import List, Tuple
from unittest.mock import AsyncMock, MagicMock
import pytest
# Remove this to test changes to GithubRepositoryReader.
pytest.skip(
"Skip by default due to dependence on network request and github api token.",
allow_module_level=True,
)
pytest.skip("Skip by default due to network request.", allow_module_level=True)
from loader_hub.github_repo.utils import (
BufferedAsyncIterator,
BufferedGitBlobDataIterator,
)
import base64
import os
from loader_hub.github_repo.github_client import (
GithubClient,
GitBlobResponseModel,
GitTreeResponseModel,
)
import pytest
from gpt_index import Document
from loader_hub.github_repo.base import GithubRepositoryReader
## Test BufferedAsyncIterator ##
## and BufferedGitBlobDataIterator ##
class MockGithubClient:
async def get_blob(self, owner, repo, sha):
return f"base64-decoded string blob content {owner}/{repo}/{sha}"
@pytest.mark.asyncio
async def test_buffered_async_iterator():
class TestIterator(BufferedAsyncIterator):
def __init__(self, data: List[Tuple[str, str]], buffer_size: int = 2):
super().__init__(buffer_size)
self._data = data
async def _fill_buffer(self):
del self._buffer[:]
self._buffer = []
start = self._index
end = min(start + self._buffer_size, len(self._data))
if start >= end:
return
self._buffer = self._data[start:end]
data = [
("my-sha-1", "my/path1"),
("my-sha-2", "my/path2"),
("my-sha-3", "my/path3"),
("my-sha-4", "my/path4"),
("my-sha-5", "my/path5"),
("my-sha-6", "my/path6"),
]
iterator = TestIterator(data, buffer_size=2)
assert len(iterator._buffer) == 0
assert iterator._index == 0
assert iterator._buffer_size == 2
assert await iterator.__anext__() == ("my-sha-1", "my/path1")
assert len(iterator._buffer) == 1
assert iterator._index == 1
assert await iterator.__anext__() == ("my-sha-2", "my/path2")
assert len(iterator._buffer) == 0
assert iterator._index == 2
assert await iterator.__anext__() == ("my-sha-3", "my/path3")
assert len(iterator._buffer) == 1
assert iterator._index == 3
assert await iterator.__anext__() == ("my-sha-4", "my/path4")
assert len(iterator._buffer) == 0
assert iterator._index == 4
assert await iterator.__anext__() == ("my-sha-5", "my/path5")
assert len(iterator._buffer) == 1
assert iterator._index == 5
assert await iterator.__anext__() == ("my-sha-6", "my/path6")
assert len(iterator._buffer) == 0
assert iterator._index == 6
with pytest.raises(StopAsyncIteration):
await iterator.__anext__()
@pytest.mark.asyncio
async def test_buffered_git_blob_data_iterator():
github_client = MockGithubClient()
owner = "my-owner"
repo = "my-repo"
loop = asyncio.get_event_loop()
blobs_and_paths = [
(
GitTreeResponseModel.GitTreeObject(
sha="my-sha-1",
path="file1",
mode="100644",
type="blob",
size=123,
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-1",
),
"path/file1",
),
(
GitTreeResponseModel.GitTreeObject(
sha="my-sha-2",
path="file2",
mode="100644",
type="blob",
size=321,
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-2",
),
"path/file2",
),
(
GitTreeResponseModel.GitTreeObject(
sha="my-sha-3",
path="file3",
mode="100644",
type="blob",
size=456,
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-3",
),
"path/to/file3",
),
(
GitTreeResponseModel.GitTreeObject(
sha="my-sha-4",
path="file4",
mode="100644",
type="blob",
size=941,
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-4",
),
"path/to/file4",
),
]
it = BufferedGitBlobDataIterator(
blobs_and_paths,
github_client,
owner,
repo,
loop,
buffer_size=3,
verbose=False,
)
assert len(it._buffer) == 0
assert it._index == 0
assert it._buffer_size == 3
assert await it.__anext__() == (
f"base64-decoded string blob content {owner}/{repo}/my-sha-1",
"path/file1",
)
assert len(it._buffer) == 2
assert it._index == 1
assert await it.__anext__() == (
f"base64-decoded string blob content {owner}/{repo}/my-sha-2",
"path/file2",
)
assert len(it._buffer) == 1
assert it._index == 2
assert await it.__anext__() == (
f"base64-decoded string blob content {owner}/{repo}/my-sha-3",
"path/to/file3",
)
assert len(it._buffer) == 0
assert it._index == 3
assert await it.__anext__() == (
f"base64-decoded string blob content {owner}/{repo}/my-sha-4",
"path/to/file4",
)
assert len(it._buffer) == 0
assert it._index == 4
with pytest.raises(StopAsyncIteration):
await it.__anext__()
########################
## GithubClient tests ##
from loader_hub.github_repo import GithubClient, GithubRepositoryReader
@pytest.fixture
@ -210,8 +39,7 @@ async def test_github_client(github_client):
== f"https://api.github.com/repos/{owner}/{repo}/branches/{branch}"
), "Branch self link is incorrect"
assert (
branch_data._links.html
== f"https://github.com/{owner}/{repo}/tree/{branch}"
branch_data._links.html == f"https://github.com/{owner}/{repo}/tree/{branch}"
), "Branch html link is incorrect"
# test get_commit
@ -223,16 +51,12 @@ async def test_github_client(github_client):
), "Commit url is incorrect"
# test get_tree
tree_data = await github_client.get_tree(
owner, repo, commit_data.commit.tree.sha
)
tree_data = await github_client.get_tree(owner, repo, commit_data.commit.tree.sha)
assert (
tree_data.url
== f"https://api.github.com/repos/{owner}/{repo}/git/trees/{commit_data.commit.tree.sha}"
), "Tree url is incorrect"
assert (
tree_data.sha == commit_data.commit.tree.sha
), "Tree sha is incorrect"
assert tree_data.sha == commit_data.commit.tree.sha, "Tree sha is incorrect"
print(tree_data.tree[0].sha)
assert 1 == 1
@ -271,9 +95,7 @@ async def test_github_client(github_client):
][0]
# test get_blob
blob_data = await github_client.get_blob(
owner, repo, test_requirements_txt.sha
)
blob_data = await github_client.get_blob(owner, repo, test_requirements_txt.sha)
assert blob_data.encoding == "base64", "Blob encoding is incorrect"
assert (
blob_data.url
@ -307,557 +129,3 @@ isort==5.11.4
filter(lambda x: x != "", expected_decoded_blob_content.splitlines()),
):
assert dbc[0] == dbc[1], f"{dbc[0]} is not equal to {dbc[1]}"
class TestGithubRepositoryReader(unittest.TestCase):
def setUp(self):
self.github_client = MagicMock()
self.owner = "owner"
self.repo = "repo"
self.reader = GithubRepositoryReader(
self.github_client, self.owner, self.repo
)
def test__check_filter_directories(self):
tree_obj_path = "path/to/some/file.py"
self.reader._filter_directories = (
["path/to"],
GithubRepositoryReader.FilterType.INCLUDE,
)
self.assertTrue(self.reader._check_filter_directories(tree_obj_path))
self.reader._filter_directories = (
["path/to"],
GithubRepositoryReader.FilterType.EXCLUDE,
)
self.assertFalse(self.reader._check_filter_directories(tree_obj_path))
def test__check_filter_file_extensions(self):
tree_obj_path = "path/to/some/file.py"
self.reader._filter_file_extensions = (
[".py"],
GithubRepositoryReader.FilterType.INCLUDE,
)
self.assertTrue(
self.reader._check_filter_file_extensions(tree_obj_path)
)
self.reader._filter_file_extensions = (
[".txt"],
GithubRepositoryReader.FilterType.EXCLUDE,
)
self.assertTrue(
self.reader._check_filter_file_extensions(tree_obj_path)
)
def test__allow_tree_obj(self):
tree_obj_paths = [
"src/file.py",
"src/file.txt",
"src/Path.To.Folder/file1.js",
"src/Path.To.Folder/file2.cpp",
"src/Path.To.Folder/file4.rs",
"src/Path.To.Folder/file5.ts",
"src/Path.To.Folder/file6.h",
"src/Path.To.Folder/file7.c",
"src/Path.To.Folder/file8.java",
"src/dir1/file.js",
"src/assets/file.png",
"src/assets/file.jpg",
"src/assets/file.jpeg",
"src/assets/file.gif",
"src/assets/file.svg",
"src/assets/file.ico",
"src/documents/file.pdf",
"src/documents/file.doc",
"src/documents/file.docx",
"src/documents/file.xls",
"src/documents/file.xlsx",
"src/documents/file.ppt",
"src/documents/file.pptx",
"src/documents/file.odt",
"src/documents/file.ods",
"src/dir2/subdir/file.cpp",
"src/dir2/subdir/file.c",
"src/dir2/subdir/file.h",
"src/dir2/subdir/file.hpp",
"src/dir2/subdir/file.java",
"src/dir2/foo.cc",
"src/dir2/foo.svg",
"src/dir2/subdir/file.go",
"src/sub/folder/loading.svg",
"src/sub/folder/loading.ico",
]
self.reader._filter_directories = (
["src/assets", "src/documents"],
GithubRepositoryReader.FilterType.EXCLUDE,
)
self.reader._filter_file_extensions = (
[".svg", ".ico", ".cpp", ".c", ".h"],
GithubRepositoryReader.FilterType.EXCLUDE,
)
expected_tree_obj_paths = [
"src/file.py",
"src/file.txt",
"src/Path.To.Folder/file1.js",
"src/Path.To.Folder/file4.rs",
"src/Path.To.Folder/file5.ts",
"src/Path.To.Folder/file8.java",
"src/dir1/file.js",
"src/dir2/subdir/file.hpp",
"src/dir2/subdir/file.java",
"src/dir2/foo.cc",
"src/dir2/subdir/file.go",
]
actual_tree_obj_paths = [
tree_obj_path
for tree_obj_path in tree_obj_paths
if self.reader._allow_tree_obj(tree_obj_path)
]
print(f"Expected: {expected_tree_obj_paths}")
print(f"Actual: {actual_tree_obj_paths}")
self.assertCountEqual(
expected_tree_obj_paths, actual_tree_obj_paths
), "Tree object paths are incorrect"
self.reader._filter_directories = (
["src/dir2/subdir", "src/documents", "src/Path.To.Folder"],
GithubRepositoryReader.FilterType.INCLUDE,
)
self.reader._filter_file_extensions = (
[".png", ".svg", ".ico", "jpg", ".java"],
GithubRepositoryReader.FilterType.EXCLUDE,
)
expected_tree_obj_paths = [
"src/Path.To.Folder/file1.js",
"src/Path.To.Folder/file2.cpp",
"src/Path.To.Folder/file4.rs",
"src/Path.To.Folder/file5.ts",
"src/Path.To.Folder/file6.h",
"src/Path.To.Folder/file7.c",
"src/documents/file.pdf",
"src/documents/file.doc",
"src/documents/file.docx",
"src/documents/file.xls",
"src/documents/file.xlsx",
"src/documents/file.ppt",
"src/documents/file.pptx",
"src/documents/file.odt",
"src/documents/file.ods",
"src/dir2/subdir/file.cpp",
"src/dir2/subdir/file.c",
"src/dir2/subdir/file.h",
"src/dir2/subdir/file.hpp",
"src/dir2/subdir/file.go",
]
actual_tree_obj_paths = [
tree_obj_path
for tree_obj_path in tree_obj_paths
if self.reader._allow_tree_obj(tree_obj_path)
]
print(f"Expected: {expected_tree_obj_paths}")
print(f"Actual: {actual_tree_obj_paths}")
self.assertCountEqual(
expected_tree_obj_paths, actual_tree_obj_paths
), "Tree object paths are incorrect"
## Test GithubRepositoryReader's _recurse_tree method
@pytest.mark.asyncio
async def test__recurse_tree():
github_client = MagicMock()
owner = "owner"
repo = "repo"
reader = GithubRepositoryReader(github_client, owner, repo)
# return value for the first call to get_tree (the root tree)
tree_sha = "1234"
tree_data = GitTreeResponseModel(
sha=tree_sha,
tree=[
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file1.py",
sha="5678",
mode="100644",
size=1111,
url="https://api.github.com/repos/owner/repo/git/blobs/5678",
),
GitTreeResponseModel.GitTreeObject(
type="tree",
path="folder1",
sha="91011",
mode="040000",
size=None,
url="https://api.github.com/repos/owner/repo/git/blobs/91011",
),
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file2.py",
sha="1213",
mode="100644",
size=3333,
url="https://api.github.com/repos/owner/repo/git/blobs/1213",
),
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file3.py",
sha="1415",
mode="100644",
size=4444,
url="https://api.github.com/repos/owner/repo/git/blobs/1415",
),
],
truncated=False,
url="https://api.github.com/repos/owner/repo/git/trees/1234",
)
def get_tree_side_effect(owner, repo, sha):
if sha == tree_sha:
return tree_data
elif sha == "91011":
# return value for the second call to get_tree (the tree for folder1)
return GitTreeResponseModel(
tree=[
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file4.py",
sha="1617",
mode="100644",
size=6666,
url="https://api.github.com/repos/owner/repo/git/blobs/1617",
),
GitTreeResponseModel.GitTreeObject(
type="tree",
path="folder3",
sha="1819",
mode="040000",
size=None,
url="https://api.github.com/repos/owner/repo/git/blobs/1819",
),
],
sha="91011",
truncated=False,
url="https://api.github.com/repos/owner/repo/git/trees/91011",
)
elif sha == "1819":
# return value for the third call to get_tree (the tree for folder3)
return GitTreeResponseModel(
tree=[
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file5.py",
sha="2021",
mode="100644",
size=8888,
url="https://api.github.com/repos/owner/repo/git/blobs/2021",
),
],
sha="1819",
truncated=False,
url="https://api.github.com/repos/owner/repo/git/trees/1819",
)
else:
raise httpx.HTTPError(
f"404 Client Error: Not Found for url: https://api.github.com/repos/{owner}/{repo}/git/trees/{sha}"
)
github_client.get_tree = AsyncMock(side_effect=get_tree_side_effect)
expected_blobs_and_full_paths = [
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file1.py",
sha="5678",
mode="100644",
size=1111,
url="https://api.github.com/repos/owner/repo/git/blobs/5678",
),
"file1.py",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file2.py",
sha="1213",
mode="100644",
size=3333,
url="https://api.github.com/repos/owner/repo/git/blobs/1213",
),
"file2.py",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file4.py",
sha="1617",
mode="100644",
size=6666,
url="https://api.github.com/repos/owner/repo/git/blobs/1617",
),
"folder1/file4.py",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file5.py",
sha="2021",
mode="100644",
size=8888,
url="https://api.github.com/repos/owner/repo/git/blobs/2021",
),
"folder1/folder3/file5.py",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file3.py",
sha="1415",
mode="100644",
size=4444,
url="https://api.github.com/repos/owner/repo/git/blobs/1415",
),
"file3.py",
),
]
blobs_and_full_paths = await reader._recurse_tree(tree_sha)
# make sure get_tree was called the expected number of times
assert (
github_client.get_tree.call_count == 3
), "There should be only 3 calls to get_tree (one for the root tree, and one for each subfolder folder1 and folder3)"
# sort the expected and actual results by full path so we can compare them
for (blob, full_path), (expected_blob, expected_full_path) in zip(
sorted(blobs_and_full_paths, key=lambda x: x[1]),
sorted(expected_blobs_and_full_paths, key=lambda x: x[1]),
):
assert (
blob == expected_blob
), "actual blob info does not match expected blob info"
assert (
full_path == expected_full_path
), "actual full path does not match expected full path"
with pytest.raises(
httpx.HTTPError,
match="404 Client Error: Not Found for url: https://api.github.com/repos/owner/repo/git/trees/12345",
):
await reader._recurse_tree("12345")
@pytest.mark.asyncio
async def test__generate_documents():
github_client = MagicMock()
owner = "owner"
repo = "repo"
reader = GithubRepositoryReader(
github_client=github_client,
owner=owner,
repo=repo,
use_parser=False,
verbose=False,
)
blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = [
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file1.py",
sha="5678",
mode="100644",
size=1111,
url="https://api.github.com/repos/owner/repo/git/blobs/5678",
),
"file1.py",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file2.ts",
sha="1213",
mode="100644",
size=3333,
url="https://api.github.com/repos/owner/repo/git/blobs/1213",
),
"folder1/file2.ts",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file3.rs",
sha="1415",
mode="100644",
size=4444,
url="https://api.github.com/repos/owner/repo/git/blobs/1415",
),
"folder1/folder2/file3.rs",
),
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file4.cc",
sha="1617",
mode="100644",
size=6666,
url="https://api.github.com/repos/owner/repo/git/blobs/1617",
),
"folder1/folder2/folder3/file4.cc",
),
(
GitTreeResponseModel.GitTreeObject( # this file should not end up in the generated documents since it should fail to decode as utf-8
type="blob",
path="file5.png",
sha="2021",
mode="100644",
size=8888,
url="https://api.github.com/repos/owner/repo/git/blobs/2021",
),
"folder1/folder2/folder3/file5.png",
),
]
async def get_blob_side_effect(owner: str, repo: str, sha: str):
if sha == "5678":
return GitBlobResponseModel(
content="cHJpbnQoJ2hlbGxvIHdvcmxkJyk=",
encoding="base64",
sha="5678",
size=1111,
url="https://api.github.com/repos/owner/repo/git/blobs/5678",
node_id="1234",
)
elif sha == "1213":
return GitBlobResponseModel(
content="Y29uc29sZS5sb2coJ2hlbGxvIHdvcmxkJyk=",
encoding="base64",
sha="1213",
size=3333,
url="https://api.github.com/repos/owner/repo/git/blobs/1213",
node_id="2345",
)
elif sha == "1415":
return GitBlobResponseModel(
content="cHJpbnRsbiEoImhlbGxvIHdvcmxkIik=",
encoding="base64",
sha="1415",
size=4444,
url="https://api.github.com/repos/owner/repo/git/blobs/1415",
node_id="3456",
)
elif sha == "1617":
return GitBlobResponseModel(
content="c3RkOjpjb3V0IDw8ICJoZWxsbyB3b3JsZCIgPDwgc3RkOjplbmRsOw==",
encoding="base64",
sha="1617",
size=6666,
url="https://api.github.com/repos/owner/repo/git/blobs/1617",
node_id="4567",
)
elif sha == "2021":
return GitBlobResponseModel(
content="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
encoding="base64",
sha="2021",
size=8888,
url="https://api.github.com/repos/owner/repo/git/blobs/2021",
node_id="5678",
)
else:
raise httpx.HTTPError(
f"404 Client Error: Not Found for url: https://api.github.com/repos/{owner}/{repo}/git/blobs/{sha}"
)
github_client.get_blob = AsyncMock(side_effect=get_blob_side_effect)
documents = await reader._generate_documents(blobs_and_paths)
assert (
github_client.get_blob.await_count == 5
), "get_blob should be awaited 5 times for each blob"
github_client.get_blob.assert_has_awaits(
[
call(owner, repo, "5678"),
call(owner, repo, "1213"),
call(owner, repo, "1415"),
call(owner, repo, "1617"),
call(owner, repo, "2021"),
]
), "get_blob should be awaited with the correct arguments"
assert (
len(documents) == 4
), "There should be 4 documents generated from the blobs_and_paths"
expected_documents = [
Document(
text="print('hello world')",
extra_info={
"file_path": "file1.py",
"file_name": "file1.py",
},
),
Document(
text="console.log('hello world')",
extra_info={
"file_path": "folder1/file2.ts",
"file_name": "file2.ts",
},
),
Document(
text='println!("hello world")',
extra_info={
"file_path": "folder1/folder2/file3.rs",
"file_name": "file3.rs",
},
),
Document(
text='std::cout << "hello world" << std::endl;',
extra_info={
"file_path": "folder1/folder2/folder3/file4.cc",
"file_name": "file4.cc",
},
),
]
for document, expected_document in zip(
sorted(documents, key=lambda x: x.extra_info["file_path"]),
sorted(expected_documents, key=lambda x: x.extra_info["file_path"]),
):
assert (
document.text == expected_document.text
), "The text of the document should be the decoded content of the blob"
assert (
document.extra_info == expected_document.extra_info
), "The extra_info of the document should be the file_path and file_name"
with pytest.raises(
httpx.HTTPError,
match="404 Client Error: Not Found for url: https://api.github.com/repos/owner/repo/git/blobs/12345",
):
await reader._generate_documents(
[
(
GitTreeResponseModel.GitTreeObject(
type="blob",
path="file1.py",
sha="12345",
mode="100644",
size=1111,
url="https://api.github.com/repos/owner/repo/git/blobs/12345",
),
"file1.py",
)
]
)
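
The `_generate_documents` tests above expect each base64 blob to be decoded as UTF-8 and wrapped in a `Document` whose `extra_info` records the file path and name, with undecodable binary blobs skipped. Here is a minimal sketch of that decode step, where `blob_to_document` is an illustrative helper name and the `Document` import path is the one used in the test file.

```python
import base64
from typing import Optional

from llama_index import Document


def blob_to_document(content_b64: str, full_path: str) -> Optional[Document]:
    """Decode a base64 GitHub blob into a Document, skipping non-UTF-8 blobs."""
    try:
        text = base64.b64decode(content_b64).decode("utf-8")
    except UnicodeDecodeError:
        # Binary blobs (e.g. PNGs) are dropped, as in test__generate_documents.
        return None
    return Document(
        text=text,
        extra_info={"file_path": full_path, "file_name": full_path.split("/")[-1]},
    )


doc = blob_to_document("cHJpbnQoJ2hlbGxvIHdvcmxkJyk=", "file1.py")
print(doc.text)  # print('hello world')
```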