mirror of
https://github.com/run-llama/llama-hub.git
synced 2025-11-03 11:20:39 +00:00
Merge pull request #100 from emptycrown/jerry/revert_gh_changes
Revert "Merge pull request #73 from ahmetkca/github-reader-test-and-fix"
This commit is contained in:
commit
75480e16d0
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,4 +1,4 @@
|
||||
*.egg-info/
|
||||
.modules
|
||||
|
||||
**/__pycache__/
|
||||
**/__pycache__/
|
||||
@ -1 +1 @@
|
||||
"""Init file."""
|
||||
"""Init file."""
|
||||
|
||||
@ -14,7 +14,7 @@ import os
|
||||
from llama_index import download_loader
|
||||
download_loader("GithubRepositoryReader")
|
||||
|
||||
from llama_index.readers.llamahub_modules.github_repo import GithubRepositoryReader, GithubClient
|
||||
from modules.github_repo import GithubRepositoryReader, GithubClient
|
||||
|
||||
github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
|
||||
loader = GithubRepositoryReader(
|
||||
@ -51,7 +51,7 @@ assert (
|
||||
from llama_index import download_loader
|
||||
download_loader("GithubRepositoryReader")
|
||||
|
||||
from llama_index.readers.llamahub_modules.github_repo import GithubClient, GithubRepositoryReader
|
||||
from modules.github_repo import GithubClient, GithubRepositoryReader
|
||||
|
||||
docs = None
|
||||
|
||||
@ -79,5 +79,5 @@ if docs is None:
|
||||
|
||||
index = GPTSimpleVectorIndex(docs)
|
||||
|
||||
index.query("Explain each LlamaIndex class?")
|
||||
index.query("Explain each index class?")
|
||||
```
|
||||
|
||||
@ -5,49 +5,26 @@ Retrieves the contents of a Github repository and returns a list of documents.
|
||||
The documents are either the contents of the files in the repository or
|
||||
the text extracted from the files using the parser.
|
||||
"""
|
||||
import os
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import enum
|
||||
import sys
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from llama_index.readers.base import BaseReader
|
||||
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
|
||||
from llama_index.readers.llamahub_modules.github_repo.github_client import (
|
||||
BaseGithubClient, GitBranchResponseModel, GitCommitResponseModel,
|
||||
GithubClient, GitTreeResponseModel)
|
||||
from llama_index.readers.llamahub_modules.github_repo.utils import (
|
||||
BufferedGitBlobDataIterator, get_file_extension, print_if_verbose)
|
||||
from llama_index.readers.schema.base import Document
|
||||
|
||||
|
||||
if "pytest" in sys.modules:
|
||||
from loader_hub.github_repo.github_client import (
|
||||
BaseGithubClient,
|
||||
GitBranchResponseModel,
|
||||
GitCommitResponseModel,
|
||||
GithubClient,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
from loader_hub.github_repo.utils import (
|
||||
BufferedGitBlobDataIterator,
|
||||
print_if_verbose,
|
||||
get_file_extension,
|
||||
)
|
||||
else:
|
||||
from llama_index.readers.llamahub_modules.github_repo.github_client import (
|
||||
BaseGithubClient,
|
||||
GithubClient,
|
||||
GitBranchResponseModel,
|
||||
GitCommitResponseModel,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
from llama_index.readers.llamahub_modules.github_repo.utils import (
|
||||
BufferedGitBlobDataIterator,
|
||||
print_if_verbose,
|
||||
get_file_extension,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -85,7 +62,7 @@ class GithubRepositoryReader(BaseReader):
|
||||
github_client: BaseGithubClient,
|
||||
owner: str,
|
||||
repo: str,
|
||||
use_parser: bool = False,
|
||||
use_parser: bool = True,
|
||||
verbose: bool = False,
|
||||
concurrent_requests: int = 5,
|
||||
filter_directories: Optional[Tuple[List[str], FilterType]] = None,
|
||||
@ -146,8 +123,6 @@ class GithubRepositoryReader(BaseReader):
|
||||
|
||||
:return: True if the tree object should be allowed, False otherwise
|
||||
"""
|
||||
if self._filter_directories is None:
|
||||
return True
|
||||
filter_directories, filter_type = self._filter_directories
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
@ -161,16 +136,17 @@ class GithubRepositoryReader(BaseReader):
|
||||
or directory.startswith(tree_obj_path)
|
||||
for directory in filter_directories
|
||||
)
|
||||
if filter_type == self.FilterType.INCLUDE:
|
||||
elif filter_type == self.FilterType.INCLUDE:
|
||||
return any(
|
||||
tree_obj_path.startswith(directory)
|
||||
or directory.startswith(tree_obj_path)
|
||||
for directory in filter_directories
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown filter type: {filter_type}. "
|
||||
"Please use either 'INCLUDE' or 'EXCLUDE'."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown filter type: {filter_type}. "
|
||||
"Please use either 'ignore' or 'include'."
|
||||
)
|
||||
|
||||
def _check_filter_file_extensions(self, tree_obj_path: str) -> bool:
|
||||
"""
|
||||
@ -180,8 +156,6 @@ class GithubRepositoryReader(BaseReader):
|
||||
|
||||
:return: True if the tree object should be allowed, False otherwise
|
||||
"""
|
||||
if self._filter_file_extensions is None:
|
||||
return True
|
||||
filter_file_extensions, filter_type = self._filter_file_extensions
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
@ -190,15 +164,14 @@ class GithubRepositoryReader(BaseReader):
|
||||
)
|
||||
|
||||
if filter_type == self.FilterType.EXCLUDE:
|
||||
return (
|
||||
get_file_extension(tree_obj_path) not in filter_file_extensions
|
||||
)
|
||||
if filter_type == self.FilterType.INCLUDE:
|
||||
return get_file_extension(tree_obj_path) not in filter_file_extensions
|
||||
elif filter_type == self.FilterType.INCLUDE:
|
||||
return get_file_extension(tree_obj_path) in filter_file_extensions
|
||||
raise ValueError(
|
||||
f"Unknown filter type: {filter_type}. "
|
||||
"Please use either 'INCLUDE' or 'EXCLUDE'."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown filter type: {filter_type}. "
|
||||
"Please use either 'ignore' or 'include'."
|
||||
)
|
||||
|
||||
def _allow_tree_obj(self, tree_obj_path: str) -> bool:
|
||||
"""
|
||||
@ -209,17 +182,13 @@ class GithubRepositoryReader(BaseReader):
|
||||
:return: True if the tree object should be allowed, False otherwise
|
||||
|
||||
"""
|
||||
is_dir_allowed = True
|
||||
if self._filter_directories is not None:
|
||||
is_dir_allowed = self._check_filter_directories(tree_obj_path)
|
||||
return self._check_filter_directories(tree_obj_path)
|
||||
|
||||
is_file_ext_allowed = True
|
||||
if self._filter_file_extensions is not None:
|
||||
is_file_ext_allowed = self._check_filter_file_extensions(
|
||||
tree_obj_path
|
||||
)
|
||||
return self._check_filter_file_extensions(tree_obj_path)
|
||||
|
||||
return is_dir_allowed and is_file_ext_allowed
|
||||
return True
|
||||
|
||||
def _load_data_from_commit(self, commit_sha: str) -> List[Document]:
|
||||
"""
|
||||
@ -231,18 +200,12 @@ class GithubRepositoryReader(BaseReader):
|
||||
|
||||
:return: list of documents
|
||||
"""
|
||||
commit_response: GitCommitResponseModel = (
|
||||
self._loop.run_until_complete(
|
||||
self._github_client.get_commit(
|
||||
self._owner, self._repo, commit_sha
|
||||
)
|
||||
)
|
||||
commit_response: GitCommitResponseModel = self._loop.run_until_complete(
|
||||
self._github_client.get_commit(self._owner, self._repo, commit_sha)
|
||||
)
|
||||
|
||||
tree_sha = commit_response.commit.tree.sha
|
||||
blobs_and_paths = self._loop.run_until_complete(
|
||||
self._recurse_tree(tree_sha)
|
||||
)
|
||||
blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))
|
||||
|
||||
print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
|
||||
|
||||
@ -265,9 +228,7 @@ class GithubRepositoryReader(BaseReader):
|
||||
)
|
||||
|
||||
tree_sha = branch_data.commit.commit.tree.sha
|
||||
blobs_and_paths = self._loop.run_until_complete(
|
||||
self._recurse_tree(tree_sha)
|
||||
)
|
||||
blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))
|
||||
|
||||
print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
|
||||
|
||||
@ -305,11 +266,7 @@ class GithubRepositoryReader(BaseReader):
|
||||
raise ValueError("You must specify one of commit or branch.")
|
||||
|
||||
async def _recurse_tree(
|
||||
self,
|
||||
tree_sha: str,
|
||||
current_path: str = "",
|
||||
current_depth: int = 0,
|
||||
max_depth: int = -1,
|
||||
self, tree_sha: str, current_path: str = "", current_depth: int = 0
|
||||
) -> Any:
|
||||
"""
|
||||
Recursively get all blob tree objects in a tree.
|
||||
@ -324,16 +281,9 @@ class GithubRepositoryReader(BaseReader):
|
||||
:return: list of tuples of
|
||||
(tree object, file's full path relative to the root of the repo)
|
||||
"""
|
||||
|
||||
if max_depth != -1 and current_depth > max_depth:
|
||||
return []
|
||||
|
||||
blobs_and_full_paths: List[
|
||||
Tuple[GitTreeResponseModel.GitTreeObject, str]
|
||||
] = []
|
||||
blobs_and_full_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = []
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
"\t" * current_depth + f"current path: {current_path}",
|
||||
self._verbose, "\t" * current_depth + f"current path: {current_path}"
|
||||
)
|
||||
|
||||
tree_data: GitTreeResponseModel = await self._github_client.get_tree(
|
||||
@ -345,37 +295,39 @@ class GithubRepositoryReader(BaseReader):
|
||||
for tree_obj in tree_data.tree:
|
||||
file_path = os.path.join(current_path, tree_obj.path)
|
||||
|
||||
if not self._allow_tree_obj(file_path):
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
"\t" * current_depth
|
||||
+ f"ignoring {tree_obj.path} due to filter",
|
||||
)
|
||||
continue
|
||||
|
||||
if tree_obj.type == "tree":
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
"\t" * current_depth + f"recursing into {tree_obj.path}",
|
||||
)
|
||||
if not self._check_filter_directories(file_path):
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
"\t" * current_depth
|
||||
+ f"ignoring directory {tree_obj.path} due to filter",
|
||||
)
|
||||
continue
|
||||
|
||||
blobs_and_full_paths.extend(
|
||||
await self._recurse_tree(
|
||||
tree_obj.sha, file_path, current_depth + 1, max_depth
|
||||
)
|
||||
await self._recurse_tree(tree_obj.sha, file_path, current_depth + 1)
|
||||
)
|
||||
elif tree_obj.type == "blob":
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
"\t" * current_depth + f"found blob {tree_obj.path}",
|
||||
self._verbose, "\t" * current_depth + f"found blob {tree_obj.path}"
|
||||
)
|
||||
if not self._check_filter_file_extensions(file_path):
|
||||
print_if_verbose(
|
||||
self._verbose,
|
||||
"\t" * current_depth
|
||||
+ f"ignoring file {tree_obj.path} due to filter",
|
||||
)
|
||||
continue
|
||||
|
||||
blobs_and_full_paths.append((tree_obj, file_path))
|
||||
return blobs_and_full_paths
|
||||
|
||||
async def _generate_documents(
|
||||
self,
|
||||
blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]],
|
||||
self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]]
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Generate documents from a list of blobs and their full paths.
|
||||
@ -396,9 +348,7 @@ class GithubRepositoryReader(BaseReader):
|
||||
|
||||
documents = []
|
||||
async for blob_data, full_path in buffered_iterator:
|
||||
print_if_verbose(
|
||||
self._verbose, f"generating document for {full_path}"
|
||||
)
|
||||
print_if_verbose(self._verbose, f"generating document for {full_path}")
|
||||
assert (
|
||||
blob_data.encoding == "base64"
|
||||
), f"blob encoding {blob_data.encoding} not supported"
|
||||
@ -454,11 +404,7 @@ class GithubRepositoryReader(BaseReader):
|
||||
return documents
|
||||
|
||||
def _parse_supported_file(
|
||||
self,
|
||||
file_path: str,
|
||||
file_content: bytes,
|
||||
tree_sha: str,
|
||||
tree_path: str,
|
||||
self, file_path: str, file_content: bytes, tree_sha: str, tree_path: str
|
||||
) -> Optional[Document]:
|
||||
"""
|
||||
Parse a file if it is supported by a parser.
|
||||
@ -492,9 +438,7 @@ class GithubRepositoryReader(BaseReader):
|
||||
tmpfile.flush()
|
||||
tmpfile.close()
|
||||
try:
|
||||
parsed_file = parser.parse_file(
|
||||
pathlib.Path(tmpfile.name)
|
||||
)
|
||||
parsed_file = parser.parse_file(pathlib.Path(tmpfile.name))
|
||||
parsed_file = "\n\n".join(parsed_file)
|
||||
except Exception as e:
|
||||
print_if_verbose(
|
||||
@ -536,9 +480,7 @@ if __name__ == "__main__":
|
||||
|
||||
return wrapper
|
||||
|
||||
github_client = GithubClient(
|
||||
github_token=os.environ["GITHUB_TOKEN"], verbose=True
|
||||
)
|
||||
github_client = GithubClient(github_token=os.environ["GITHUB_TOKEN"], verbose=True)
|
||||
|
||||
reader1 = GithubRepositoryReader(
|
||||
github_client=github_client,
|
||||
@ -551,16 +493,7 @@ if __name__ == "__main__":
|
||||
GithubRepositoryReader.FilterType.INCLUDE,
|
||||
),
|
||||
filter_file_extensions=(
|
||||
[
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".svg",
|
||||
".ico",
|
||||
"json",
|
||||
".ipynb",
|
||||
],
|
||||
[".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", "json", ".ipynb"],
|
||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||
),
|
||||
)
|
||||
|
||||
@ -104,8 +104,7 @@ class GitCommitResponseModel(DataClassJsonMixin):
|
||||
tree: Tree
|
||||
|
||||
commit: Commit
|
||||
url: str
|
||||
sha: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitBranchResponseModel(DataClassJsonMixin):
|
||||
@ -138,14 +137,7 @@ class GitBranchResponseModel(DataClassJsonMixin):
|
||||
|
||||
commit: Commit
|
||||
|
||||
@dataclass
|
||||
class Links(DataClassJsonMixin):
|
||||
self: str
|
||||
html: str
|
||||
|
||||
commit: Commit
|
||||
name: str
|
||||
_links: Links
|
||||
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
@ -5,23 +5,12 @@ This module contains utility functions for the Github readers.
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
if "pytest" in sys.modules:
|
||||
from loader_hub.github_repo.github_client import (
|
||||
GitBlobResponseModel,
|
||||
GithubClient,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
else:
|
||||
from llama_index.readers.llamahub_modules.github_repo.github_client import (
|
||||
GitBlobResponseModel,
|
||||
GithubClient,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
from gpt_index.readers.github_readers.github_api_client import (
|
||||
GitBlobResponseModel, GithubClient, GitTreeResponseModel)
|
||||
|
||||
|
||||
def print_if_verbose(verbose: bool, message: str) -> None:
|
||||
@ -164,8 +153,7 @@ class BufferedGitBlobDataIterator(BufferedAsyncIterator):
|
||||
if self._verbose:
|
||||
end_t = time.time()
|
||||
blob_names_and_sizes = [
|
||||
(blob.path, blob.size)
|
||||
for blob, _ in self._blobs_and_paths[start:end]
|
||||
(blob.path, blob.size) for blob, _ in self._blobs_and_paths[start:end]
|
||||
]
|
||||
print(
|
||||
"Time to get blobs ("
|
||||
@ -175,7 +163,5 @@ class BufferedGitBlobDataIterator(BufferedAsyncIterator):
|
||||
|
||||
self._buffer = [
|
||||
(result, path)
|
||||
for result, (_, path) in zip(
|
||||
results, self._blobs_and_paths[start:end]
|
||||
)
|
||||
for result, (_, path) in zip(results, self._blobs_and_paths[start:end])
|
||||
]
|
||||
|
||||
@ -2,10 +2,8 @@
|
||||
# For testing
|
||||
pytest==7.2.1
|
||||
pytest-dotenv==0.5.2
|
||||
pytest-asyncio
|
||||
# TODO: remove gpt_index after migration
|
||||
https://github.com/jerryjliu/gpt_index/archive/master.zip
|
||||
httpx
|
||||
|
||||
llama-index
|
||||
|
||||
|
||||
@ -1,190 +1,19 @@
|
||||
from llama_index import Document
|
||||
import httpx
|
||||
import pytest
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, call
|
||||
import unittest
|
||||
from typing import List, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Remove this to test changes to GithubRepositoryReader.
|
||||
pytest.skip(
|
||||
"Skip by default due to dependence on network request and github api token.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
pytest.skip("Skip by default due to network request.", allow_module_level=True)
|
||||
|
||||
from loader_hub.github_repo.utils import (
|
||||
BufferedAsyncIterator,
|
||||
BufferedGitBlobDataIterator,
|
||||
)
|
||||
import base64
|
||||
import os
|
||||
|
||||
from loader_hub.github_repo.github_client import (
|
||||
GithubClient,
|
||||
GitBlobResponseModel,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
import pytest
|
||||
from gpt_index import Document
|
||||
|
||||
from loader_hub.github_repo.base import GithubRepositoryReader
|
||||
|
||||
|
||||
## Test BufferedAsyncIterator ##
|
||||
## and BufferedGitBlobDataIterator ##
|
||||
|
||||
|
||||
class MockGithubClient:
|
||||
async def get_blob(self, owner, repo, sha):
|
||||
return f"base64-decoded string blob content {owner}/{repo}/{sha}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_async_iterator():
|
||||
class TestIterator(BufferedAsyncIterator):
|
||||
def __init__(self, data: List[Tuple[str, str]], buffer_size: int = 2):
|
||||
super().__init__(buffer_size)
|
||||
self._data = data
|
||||
|
||||
async def _fill_buffer(self):
|
||||
del self._buffer[:]
|
||||
self._buffer = []
|
||||
start = self._index
|
||||
end = min(start + self._buffer_size, len(self._data))
|
||||
|
||||
if start >= end:
|
||||
return
|
||||
|
||||
self._buffer = self._data[start:end]
|
||||
|
||||
data = [
|
||||
("my-sha-1", "my/path1"),
|
||||
("my-sha-2", "my/path2"),
|
||||
("my-sha-3", "my/path3"),
|
||||
("my-sha-4", "my/path4"),
|
||||
("my-sha-5", "my/path5"),
|
||||
("my-sha-6", "my/path6"),
|
||||
]
|
||||
iterator = TestIterator(data, buffer_size=2)
|
||||
assert len(iterator._buffer) == 0
|
||||
assert iterator._index == 0
|
||||
assert iterator._buffer_size == 2
|
||||
assert await iterator.__anext__() == ("my-sha-1", "my/path1")
|
||||
assert len(iterator._buffer) == 1
|
||||
assert iterator._index == 1
|
||||
assert await iterator.__anext__() == ("my-sha-2", "my/path2")
|
||||
assert len(iterator._buffer) == 0
|
||||
assert iterator._index == 2
|
||||
assert await iterator.__anext__() == ("my-sha-3", "my/path3")
|
||||
assert len(iterator._buffer) == 1
|
||||
assert iterator._index == 3
|
||||
assert await iterator.__anext__() == ("my-sha-4", "my/path4")
|
||||
assert len(iterator._buffer) == 0
|
||||
assert iterator._index == 4
|
||||
assert await iterator.__anext__() == ("my-sha-5", "my/path5")
|
||||
assert len(iterator._buffer) == 1
|
||||
assert iterator._index == 5
|
||||
assert await iterator.__anext__() == ("my-sha-6", "my/path6")
|
||||
assert len(iterator._buffer) == 0
|
||||
assert iterator._index == 6
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await iterator.__anext__()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffered_git_blob_data_iterator():
|
||||
github_client = MockGithubClient()
|
||||
owner = "my-owner"
|
||||
repo = "my-repo"
|
||||
loop = asyncio.get_event_loop()
|
||||
blobs_and_paths = [
|
||||
(
|
||||
GitTreeResponseModel.GitTreeObject(
|
||||
sha="my-sha-1",
|
||||
path="file1",
|
||||
mode="100644",
|
||||
type="blob",
|
||||
size=123,
|
||||
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-1",
|
||||
),
|
||||
"path/file1",
|
||||
),
|
||||
(
|
||||
GitTreeResponseModel.GitTreeObject(
|
||||
sha="my-sha-2",
|
||||
path="file2",
|
||||
mode="100644",
|
||||
type="blob",
|
||||
size=321,
|
||||
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-2",
|
||||
),
|
||||
"path/file2",
|
||||
),
|
||||
(
|
||||
GitTreeResponseModel.GitTreeObject(
|
||||
sha="my-sha-3",
|
||||
path="file3",
|
||||
mode="100644",
|
||||
type="blob",
|
||||
size=456,
|
||||
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-3",
|
||||
),
|
||||
"path/to/file3",
|
||||
),
|
||||
(
|
||||
GitTreeResponseModel.GitTreeObject(
|
||||
sha="my-sha-4",
|
||||
path="file4",
|
||||
mode="100644",
|
||||
type="blob",
|
||||
size=941,
|
||||
url="https://api.github.com/repos/octocat/Hello-World/git/blobs/my-sha-4",
|
||||
),
|
||||
"path/to/file4",
|
||||
),
|
||||
]
|
||||
|
||||
it = BufferedGitBlobDataIterator(
|
||||
blobs_and_paths,
|
||||
github_client,
|
||||
owner,
|
||||
repo,
|
||||
loop,
|
||||
buffer_size=3,
|
||||
verbose=False,
|
||||
)
|
||||
assert len(it._buffer) == 0
|
||||
assert it._index == 0
|
||||
assert it._buffer_size == 3
|
||||
assert await it.__anext__() == (
|
||||
f"base64-decoded string blob content {owner}/{repo}/my-sha-1",
|
||||
"path/file1",
|
||||
)
|
||||
assert len(it._buffer) == 2
|
||||
assert it._index == 1
|
||||
assert await it.__anext__() == (
|
||||
f"base64-decoded string blob content {owner}/{repo}/my-sha-2",
|
||||
"path/file2",
|
||||
)
|
||||
assert len(it._buffer) == 1
|
||||
assert it._index == 2
|
||||
assert await it.__anext__() == (
|
||||
f"base64-decoded string blob content {owner}/{repo}/my-sha-3",
|
||||
"path/to/file3",
|
||||
)
|
||||
assert len(it._buffer) == 0
|
||||
assert it._index == 3
|
||||
assert await it.__anext__() == (
|
||||
f"base64-decoded string blob content {owner}/{repo}/my-sha-4",
|
||||
"path/to/file4",
|
||||
)
|
||||
assert len(it._buffer) == 0
|
||||
assert it._index == 4
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await it.__anext__()
|
||||
|
||||
|
||||
########################
|
||||
|
||||
## GithubClient tests ##
|
||||
from loader_hub.github_repo import GithubClient, GithubRepositoryReader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -210,8 +39,7 @@ async def test_github_client(github_client):
|
||||
== f"https://api.github.com/repos/{owner}/{repo}/branches/{branch}"
|
||||
), "Branch self link is incorrect"
|
||||
assert (
|
||||
branch_data._links.html
|
||||
== f"https://github.com/{owner}/{repo}/tree/{branch}"
|
||||
branch_data._links.html == f"https://github.com/{owner}/{repo}/tree/{branch}"
|
||||
), "Branch html link is incorrect"
|
||||
|
||||
# test get_commit
|
||||
@ -223,16 +51,12 @@ async def test_github_client(github_client):
|
||||
), "Commit url is incorrect"
|
||||
|
||||
# test get_tree
|
||||
tree_data = await github_client.get_tree(
|
||||
owner, repo, commit_data.commit.tree.sha
|
||||
)
|
||||
tree_data = await github_client.get_tree(owner, repo, commit_data.commit.tree.sha)
|
||||
assert (
|
||||
tree_data.url
|
||||
== f"https://api.github.com/repos/{owner}/{repo}/git/trees/{commit_data.commit.tree.sha}"
|
||||
), "Tree url is incorrect"
|
||||
assert (
|
||||
tree_data.sha == commit_data.commit.tree.sha
|
||||
), "Tree sha is incorrect"
|
||||
assert tree_data.sha == commit_data.commit.tree.sha, "Tree sha is incorrect"
|
||||
print(tree_data.tree[0].sha)
|
||||
assert 1 == 1
|
||||
|
||||
@ -271,9 +95,7 @@ async def test_github_client(github_client):
|
||||
][0]
|
||||
|
||||
# test get_blob
|
||||
blob_data = await github_client.get_blob(
|
||||
owner, repo, test_requirements_txt.sha
|
||||
)
|
||||
blob_data = await github_client.get_blob(owner, repo, test_requirements_txt.sha)
|
||||
assert blob_data.encoding == "base64", "Blob encoding is incorrect"
|
||||
assert (
|
||||
blob_data.url
|
||||
@ -307,557 +129,3 @@ isort==5.11.4
|
||||
filter(lambda x: x != "", expected_decoded_blob_content.splitlines()),
|
||||
):
|
||||
assert dbc[0] == dbc[1], f"{dbc[0]} is not equal to {dbc[1]}"
|
||||
|
||||
|
||||
class TestGithubRepositoryReader(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.github_client = MagicMock()
|
||||
self.owner = "owner"
|
||||
self.repo = "repo"
|
||||
self.reader = GithubRepositoryReader(
|
||||
self.github_client, self.owner, self.repo
|
||||
)
|
||||
|
||||
def test__check_filter_directories(self):
|
||||
tree_obj_path = "path/to/some/file.py"
|
||||
self.reader._filter_directories = (
|
||||
["path/to"],
|
||||
GithubRepositoryReader.FilterType.INCLUDE,
|
||||
)
|
||||
self.assertTrue(self.reader._check_filter_directories(tree_obj_path))
|
||||
|
||||
self.reader._filter_directories = (
|
||||
["path/to"],
|
||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||
)
|
||||
self.assertFalse(self.reader._check_filter_directories(tree_obj_path))
|
||||
|
||||
def test__check_filter_file_extensions(self):
|
||||
tree_obj_path = "path/to/some/file.py"
|
||||
self.reader._filter_file_extensions = (
|
||||
[".py"],
|
||||
GithubRepositoryReader.FilterType.INCLUDE,
|
||||
)
|
||||
self.assertTrue(
|
||||
self.reader._check_filter_file_extensions(tree_obj_path)
|
||||
)
|
||||
|
||||
self.reader._filter_file_extensions = (
|
||||
[".txt"],
|
||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||
)
|
||||
self.assertTrue(
|
||||
self.reader._check_filter_file_extensions(tree_obj_path)
|
||||
)
|
||||
|
||||
def test__allow_tree_obj(self):
|
||||
tree_obj_paths = [
|
||||
"src/file.py",
|
||||
"src/file.txt",
|
||||
"src/Path.To.Folder/file1.js",
|
||||
"src/Path.To.Folder/file2.cpp",
|
||||
"src/Path.To.Folder/file4.rs",
|
||||
"src/Path.To.Folder/file5.ts",
|
||||
"src/Path.To.Folder/file6.h",
|
||||
"src/Path.To.Folder/file7.c",
|
||||
"src/Path.To.Folder/file8.java",
|
||||
"src/dir1/file.js",
|
||||
"src/assets/file.png",
|
||||
"src/assets/file.jpg",
|
||||
"src/assets/file.jpeg",
|
||||
"src/assets/file.gif",
|
||||
"src/assets/file.svg",
|
||||
"src/assets/file.ico",
|
||||
"src/documents/file.pdf",
|
||||
"src/documents/file.doc",
|
||||
"src/documents/file.docx",
|
||||
"src/documents/file.xls",
|
||||
"src/documents/file.xlsx",
|
||||
"src/documents/file.ppt",
|
||||
"src/documents/file.pptx",
|
||||
"src/documents/file.odt",
|
||||
"src/documents/file.ods",
|
||||
"src/dir2/subdir/file.cpp",
|
||||
"src/dir2/subdir/file.c",
|
||||
"src/dir2/subdir/file.h",
|
||||
"src/dir2/subdir/file.hpp",
|
||||
"src/dir2/subdir/file.java",
|
||||
"src/dir2/foo.cc",
|
||||
"src/dir2/foo.svg",
|
||||
"src/dir2/subdir/file.go",
|
||||
"src/sub/folder/loading.svg",
|
||||
"src/sub/folder/loading.ico",
|
||||
]
|
||||
self.reader._filter_directories = (
|
||||
["src/assets", "src/documents"],
|
||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||
)
|
||||
self.reader._filter_file_extensions = (
|
||||
[".svg", ".ico", ".cpp", ".c", ".h"],
|
||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||
)
|
||||
|
||||
expected_tree_obj_paths = [
|
||||
"src/file.py",
|
||||
"src/file.txt",
|
||||
"src/Path.To.Folder/file1.js",
|
||||
"src/Path.To.Folder/file4.rs",
|
||||
"src/Path.To.Folder/file5.ts",
|
||||
"src/Path.To.Folder/file8.java",
|
||||
"src/dir1/file.js",
|
||||
"src/dir2/subdir/file.hpp",
|
||||
"src/dir2/subdir/file.java",
|
||||
"src/dir2/foo.cc",
|
||||
"src/dir2/subdir/file.go",
|
||||
]
|
||||
|
||||
actual_tree_obj_paths = [
|
||||
tree_obj_path
|
||||
for tree_obj_path in tree_obj_paths
|
||||
if self.reader._allow_tree_obj(tree_obj_path)
|
||||
]
|
||||
|
||||
print(f"Expected: {expected_tree_obj_paths}")
|
||||
print(f"Actual: {actual_tree_obj_paths}")
|
||||
self.assertCountEqual(
|
||||
expected_tree_obj_paths, actual_tree_obj_paths
|
||||
), "Tree object paths are incorrect"
|
||||
|
||||
self.reader._filter_directories = (
|
||||
["src/dir2/subdir", "src/documents", "src/Path.To.Folder"],
|
||||
GithubRepositoryReader.FilterType.INCLUDE,
|
||||
)
|
||||
self.reader._filter_file_extensions = (
|
||||
[".png", ".svg", ".ico", "jpg", ".java"],
|
||||
GithubRepositoryReader.FilterType.EXCLUDE,
|
||||
)
|
||||
|
||||
expected_tree_obj_paths = [
|
||||
"src/Path.To.Folder/file1.js",
|
||||
"src/Path.To.Folder/file2.cpp",
|
||||
"src/Path.To.Folder/file4.rs",
|
||||
"src/Path.To.Folder/file5.ts",
|
||||
"src/Path.To.Folder/file6.h",
|
||||
"src/Path.To.Folder/file7.c",
|
||||
"src/documents/file.pdf",
|
||||
"src/documents/file.doc",
|
||||
"src/documents/file.docx",
|
||||
"src/documents/file.xls",
|
||||
"src/documents/file.xlsx",
|
||||
"src/documents/file.ppt",
|
||||
"src/documents/file.pptx",
|
||||
"src/documents/file.odt",
|
||||
"src/documents/file.ods",
|
||||
"src/dir2/subdir/file.cpp",
|
||||
"src/dir2/subdir/file.c",
|
||||
"src/dir2/subdir/file.h",
|
||||
"src/dir2/subdir/file.hpp",
|
||||
"src/dir2/subdir/file.go",
|
||||
]
|
||||
|
||||
actual_tree_obj_paths = [
|
||||
tree_obj_path
|
||||
for tree_obj_path in tree_obj_paths
|
||||
if self.reader._allow_tree_obj(tree_obj_path)
|
||||
]
|
||||
|
||||
print(f"Expected: {expected_tree_obj_paths}")
|
||||
print(f"Actual: {actual_tree_obj_paths}")
|
||||
self.assertCountEqual(
|
||||
expected_tree_obj_paths, actual_tree_obj_paths
|
||||
), "Tree object paths are incorrect"
|
||||
|
||||
|
||||
## Test GithubRepositoryReader's _recurse_tree method
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__recurse_tree():
    """Check that ``_recurse_tree`` collects every blob with its full path.

    ``get_tree`` is mocked to serve a three-level layout::

        1234 (root)
        |-- file1.py
        |-- folder1/            (sha 91011)
        |   |-- file4.py
        |   `-- folder3/        (sha 1819)
        |       `-- file5.py
        |-- file2.py
        `-- file3.py

    The test expects exactly three ``get_tree`` calls (one per tree), five
    ``(blob, full_path)`` pairs whose paths are prefixed with their parent
    folders, and an ``httpx.HTTPError`` for a tree sha the mock does not know.
    """
    github_client = MagicMock()
    owner = "owner"
    repo = "repo"
    reader = GithubRepositoryReader(github_client, owner, repo)

    # return value for the first call to get_tree (the root tree)
    tree_sha = "1234"
    tree_data = GitTreeResponseModel(
        sha=tree_sha,
        tree=[
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file1.py",
                sha="5678",
                mode="100644",
                size=1111,
                url="https://api.github.com/repos/owner/repo/git/blobs/5678",
            ),
            GitTreeResponseModel.GitTreeObject(
                type="tree",
                path="folder1",
                sha="91011",
                mode="040000",
                size=None,
                url="https://api.github.com/repos/owner/repo/git/blobs/91011",
            ),
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file2.py",
                sha="1213",
                mode="100644",
                size=3333,
                url="https://api.github.com/repos/owner/repo/git/blobs/1213",
            ),
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file3.py",
                sha="1415",
                mode="100644",
                size=4444,
                url="https://api.github.com/repos/owner/repo/git/blobs/1415",
            ),
        ],
        truncated=False,
        url="https://api.github.com/repos/owner/repo/git/trees/1234",
    )

    def get_tree_side_effect(owner, repo, sha):
        """Serve the canned tree for a known sha, or raise a 404-style error."""
        if sha == tree_sha:
            return tree_data
        elif sha == "91011":
            # return value for the second call to get_tree (the tree for folder1)
            return GitTreeResponseModel(
                tree=[
                    GitTreeResponseModel.GitTreeObject(
                        type="blob",
                        path="file4.py",
                        sha="1617",
                        mode="100644",
                        size=6666,
                        url="https://api.github.com/repos/owner/repo/git/blobs/1617",
                    ),
                    GitTreeResponseModel.GitTreeObject(
                        type="tree",
                        path="folder3",
                        sha="1819",
                        mode="040000",
                        size=None,
                        url="https://api.github.com/repos/owner/repo/git/blobs/1819",
                    ),
                ],
                sha="91011",
                truncated=False,
                url="https://api.github.com/repos/owner/repo/git/trees/91011",
            )
        elif sha == "1819":
            # return value for the third call to get_tree (the tree for folder3)
            return GitTreeResponseModel(
                tree=[
                    GitTreeResponseModel.GitTreeObject(
                        type="blob",
                        path="file5.py",
                        sha="2021",
                        mode="100644",
                        size=8888,
                        url="https://api.github.com/repos/owner/repo/git/blobs/2021",
                    ),
                ],
                sha="1819",
                truncated=False,
                url="https://api.github.com/repos/owner/repo/git/trees/1819",
            )
        else:
            raise httpx.HTTPError(
                f"404 Client Error: Not Found for url: https://api.github.com/repos/{owner}/{repo}/git/trees/{sha}"
            )

    github_client.get_tree = AsyncMock(side_effect=get_tree_side_effect)

    expected_blobs_and_full_paths = [
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file1.py",
                sha="5678",
                mode="100644",
                size=1111,
                url="https://api.github.com/repos/owner/repo/git/blobs/5678",
            ),
            "file1.py",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file2.py",
                sha="1213",
                mode="100644",
                size=3333,
                url="https://api.github.com/repos/owner/repo/git/blobs/1213",
            ),
            "file2.py",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file4.py",
                sha="1617",
                mode="100644",
                size=6666,
                url="https://api.github.com/repos/owner/repo/git/blobs/1617",
            ),
            "folder1/file4.py",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file5.py",
                sha="2021",
                mode="100644",
                size=8888,
                url="https://api.github.com/repos/owner/repo/git/blobs/2021",
            ),
            "folder1/folder3/file5.py",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file3.py",
                sha="1415",
                mode="100644",
                size=4444,
                url="https://api.github.com/repos/owner/repo/git/blobs/1415",
            ),
            "file3.py",
        ),
    ]

    blobs_and_full_paths = await reader._recurse_tree(tree_sha)

    # make sure get_tree was called the expected number of times
    assert (
        github_client.get_tree.call_count == 3
    ), "There should be only 3 calls to get_tree (one for the root tree, and one for each subfolder folder1 and folder3)"

    # sort the expected and actual results by full path so we can compare them
    actual_sorted = sorted(blobs_and_full_paths, key=lambda x: x[1])
    expected_sorted = sorted(expected_blobs_and_full_paths, key=lambda x: x[1])

    # zip() stops at the shorter input, so without this check a result with
    # missing (or extra) blobs would slip through the pairwise comparison.
    assert len(actual_sorted) == len(
        expected_sorted
    ), "actual number of blobs does not match expected number of blobs"

    for (blob, full_path), (expected_blob, expected_full_path) in zip(
        actual_sorted, expected_sorted
    ):
        assert (
            blob == expected_blob
        ), "actual blob info does not match expected blob info"
        assert (
            full_path == expected_full_path
        ), "actual full path does not match expected full path"

    # an unknown tree sha makes the mocked client raise; the error must propagate
    with pytest.raises(
        httpx.HTTPError,
        match="404 Client Error: Not Found for url: https://api.github.com/repos/owner/repo/git/trees/12345",
    ):
        await reader._recurse_tree("12345")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test__generate_documents():
    """Check that ``_generate_documents`` turns fetched blobs into documents.

    ``get_blob`` is mocked to return base64-encoded content for five blobs:
    four small source snippets and one PNG image. The test expects one
    ``get_blob`` await per blob, four resulting documents (the PNG is expected
    to be dropped) carrying the decoded text plus ``file_path``/``file_name``
    in ``extra_info``, and an ``httpx.HTTPError`` for a blob sha the mock does
    not know.
    """
    github_client = MagicMock()
    owner = "owner"
    repo = "repo"
    reader = GithubRepositoryReader(
        github_client=github_client,
        owner=owner,
        repo=repo,
        use_parser=False,
        verbose=False,
    )

    blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = [
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file1.py",
                sha="5678",
                mode="100644",
                size=1111,
                url="https://api.github.com/repos/owner/repo/git/blobs/5678",
            ),
            "file1.py",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file2.ts",
                sha="1213",
                mode="100644",
                size=3333,
                url="https://api.github.com/repos/owner/repo/git/blobs/1213",
            ),
            "folder1/file2.ts",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file3.rs",
                sha="1415",
                mode="100644",
                size=4444,
                url="https://api.github.com/repos/owner/repo/git/blobs/1415",
            ),
            "folder1/folder2/file3.rs",
        ),
        (
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file4.cc",
                sha="1617",
                mode="100644",
                size=6666,
                url="https://api.github.com/repos/owner/repo/git/blobs/1617",
            ),
            "folder1/folder2/folder3/file4.cc",
        ),
        (
            # this file should not end up in the generated documents since it
            # should fail to decode as utf-8
            GitTreeResponseModel.GitTreeObject(
                type="blob",
                path="file5.png",
                sha="2021",
                mode="100644",
                size=8888,
                url="https://api.github.com/repos/owner/repo/git/blobs/2021",
            ),
            "folder1/folder2/folder3/file5.png",
        ),
    ]

    async def get_blob_side_effect(owner: str, repo: str, sha: str):
        """Serve the canned blob for a known sha, or raise a 404-style error."""
        if sha == "5678":
            return GitBlobResponseModel(
                content="cHJpbnQoJ2hlbGxvIHdvcmxkJyk=",
                encoding="base64",
                sha="5678",
                size=1111,
                url="https://api.github.com/repos/owner/repo/git/blobs/5678",
                node_id="1234",
            )
        elif sha == "1213":
            return GitBlobResponseModel(
                content="Y29uc29sZS5sb2coJ2hlbGxvIHdvcmxkJyk=",
                encoding="base64",
                sha="1213",
                size=3333,
                url="https://api.github.com/repos/owner/repo/git/blobs/1213",
                node_id="2345",
            )
        elif sha == "1415":
            return GitBlobResponseModel(
                content="cHJpbnRsbiEoImhlbGxvIHdvcmxkIik=",
                encoding="base64",
                sha="1415",
                size=4444,
                url="https://api.github.com/repos/owner/repo/git/blobs/1415",
                node_id="3456",
            )
        elif sha == "1617":
            return GitBlobResponseModel(
                content="c3RkOjpjb3V0IDw8ICJoZWxsbyB3b3JsZCIgPDwgc3RkOjplbmRsOw==",
                encoding="base64",
                sha="1617",
                size=6666,
                url="https://api.github.com/repos/owner/repo/git/blobs/1617",
                node_id="4567",
            )
        elif sha == "2021":
            return GitBlobResponseModel(
                content="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
                encoding="base64",
                sha="2021",
                size=8888,
                url="https://api.github.com/repos/owner/repo/git/blobs/2021",
                node_id="5678",
            )
        else:
            raise httpx.HTTPError(
                f"404 Client Error: Not Found for url: https://api.github.com/repos/{owner}/{repo}/git/blobs/{sha}"
            )

    github_client.get_blob = AsyncMock(side_effect=get_blob_side_effect)

    documents = await reader._generate_documents(blobs_and_paths)

    assert (
        github_client.get_blob.await_count == 5
    ), "get_blob should be awaited 5 times for each blob"

    # get_blob should be awaited with the correct arguments.
    # assert_has_awaits raises AssertionError itself and takes no message
    # argument, so the explanation is kept as a comment instead of a
    # dangling string that would never be shown.
    github_client.get_blob.assert_has_awaits(
        [
            call(owner, repo, "5678"),
            call(owner, repo, "1213"),
            call(owner, repo, "1415"),
            call(owner, repo, "1617"),
            call(owner, repo, "2021"),
        ]
    )

    assert (
        len(documents) == 4
    ), "There should be 4 documents generated from the blobs_and_paths"

    expected_documents = [
        Document(
            text="print('hello world')",
            extra_info={
                "file_path": "file1.py",
                "file_name": "file1.py",
            },
        ),
        Document(
            text="console.log('hello world')",
            extra_info={
                "file_path": "folder1/file2.ts",
                "file_name": "file2.ts",
            },
        ),
        Document(
            text='println!("hello world")',
            extra_info={
                "file_path": "folder1/folder2/file3.rs",
                "file_name": "file3.rs",
            },
        ),
        Document(
            text='std::cout << "hello world" << std::endl;',
            extra_info={
                "file_path": "folder1/folder2/folder3/file4.cc",
                "file_name": "file4.cc",
            },
        ),
    ]

    # compare pairwise after sorting both sides by file path; the length
    # assertion above guarantees zip() does not silently drop anything
    for document, expected_document in zip(
        sorted(documents, key=lambda x: x.extra_info["file_path"]),
        sorted(expected_documents, key=lambda x: x.extra_info["file_path"]),
    ):
        assert (
            document.text == expected_document.text
        ), "The text of the document should be the decoded content of the blob"
        assert (
            document.extra_info == expected_document.extra_info
        ), "The extra_info of the document should be the file_path and file_name"

    # an unknown blob sha makes the mocked client raise; the error must propagate
    with pytest.raises(
        httpx.HTTPError,
        match="404 Client Error: Not Found for url: https://api.github.com/repos/owner/repo/git/blobs/12345",
    ):
        await reader._generate_documents(
            [
                (
                    GitTreeResponseModel.GitTreeObject(
                        type="blob",
                        path="file1.py",
                        sha="12345",
                        mode="100644",
                        size=1111,
                        url="https://api.github.com/repos/owner/repo/git/blobs/12345",
                    ),
                    "file1.py",
                )
            ]
        )
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user