mirror of https://github.com/run-llama/llama-hub.git
synced 2025-12-27 06:59:06 +00:00
Add GitHub Repository Reader (#34)
* add github repository, test a new way to download loader
* test imports when downloaded from gpt_index
* Refactor(Github Repo): Move github_client and utils to modules
  - Moved github_client.py and utils.py from loader_hub/github_repo to modules/github_repo
  - Updated import statements in base.py to reflect the new location
* temp
* Refactor(GithubRepositoryReader): Add github_client argument
  - Add github_client argument to GithubRepositoryReader constructor
  - Set default value for github_client argument
  - Update docstring to reflect changes
* Refactor(Github Repo): Update init file
  - Remove imports of base, github_client and utils
  - Add imports of GithubRepositoryReader and GithubClient
  - Update __all__ to include the new imports
* Fix(library): Update library.json
  - Updated library.json to include the __init__.py file
* Refactor(GithubRepositoryReader): Add filter for directories and files
  - Add filter for directories and files in GithubRepositoryReader
  - Ignore directories and files that do not pass the filter
  - Print out if a directory or file is ignored due to the filter
* Refactor(BaseReader): Check filter files
  - Refactor `_check_filter_files` to `_check_filter_file_extensions` in `BaseReader`
  - Ignore files due to the filter
* Docs(FilterType): Add documentation for the FilterType enum
  - Explain what the enum is used for
  - Describe the attributes of the enum
* Add(GPT Index): Add GPT Index example to README
  - Set OPENAI_API_KEY environment variable
  - Download GithubRepositoryReader module
  - Create GithubClient and GithubRepositoryReader
  - Load data from Github Repository
  - Create GPTSimpleVectorIndex
  - Query the index
* change the import path for extras
* change import path for extra files to absolute
* Add test for GithubClient (currently not using mocks, which is not ideal)
* Update test_github_reader.py

---------

Co-authored-by: Jesse Zhang <jessetanzhang@gmail.com>
This commit is contained in:
parent 457e7888e9
commit 5a27264db1
@@ -1 +1 @@
-"""Init file."""
+"""Init file."""
83 loader_hub/github_repo/README.md Normal file
@@ -0,0 +1,83 @@
# Github Repository Loader

This loader takes in `owner`, `repo`, `branch`, `commit`, and other optional parameters, such as filters for directories or for files with certain extensions. It then fetches the contents of the GitHub repository.

As a prerequisite, you will need to generate a personal access token. See [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token) for instructions.

## Usage

To use this loader, pass in the `owner` and `repo`, plus either `branch` or `commit`. For example, you can use `owner = "jerryjliu"` and `repo = "gpt_index"`, together with either `branch = "main"` or `commit = "a6c89159bf8e7086bea2f4305cff3f0a4102e370"`.

```python
import os

from gpt_index import download_loader

download_loader("GithubRepositoryReader")

from modules.github_repo import GithubRepositoryReader, GithubClient

github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
loader = GithubRepositoryReader(
    github_client,
    owner="jerryjliu",
    repo="gpt_index",
    filter_directories=(["gpt_index", "docs"], GithubRepositoryReader.FilterType.INCLUDE),
    filter_file_extensions=([".py"], GithubRepositoryReader.FilterType.INCLUDE),
    verbose=True,
    concurrent_requests=10,
)

docs_branch = loader.load_data(branch="main")
docs_commit = loader.load_data(commit_sha="a6c89159bf8e7086bea2f4305cff3f0a4102e370")

for doc in docs_branch:
    print(doc.extra_info)
```
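Filters can also work the other way around. A minimal sketch using `FilterType.EXCLUDE` (same repository as above; the extension list mirrors the demo at the bottom of `base.py`):

```python
loader = GithubRepositoryReader(
    github_client,
    owner="jerryjliu",
    repo="gpt_index",
    # Keep everything except binary assets and notebooks.
    filter_file_extensions=(
        [".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".json", ".ipynb"],
        GithubRepositoryReader.FilterType.EXCLUDE,
    ),
    verbose=True,
    concurrent_requests=10,
)
```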
## Examples

This loader is designed to be used as a way to load data into [GPT Index](https://github.com/jerryjliu/gpt_index/tree/main/gpt_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent.

### GPT Index
```python
import pickle
import os

assert (
    os.getenv("OPENAI_API_KEY") is not None
), "Please set the OPENAI_API_KEY environment variable."

from gpt_index import download_loader, GPTSimpleVectorIndex

download_loader("GithubRepositoryReader")

from modules.github_repo import GithubClient, GithubRepositoryReader

docs = None
if os.path.exists("docs.pkl"):
    with open("docs.pkl", "rb") as f:
        docs = pickle.load(f)

if docs is None:
    github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
    loader = GithubRepositoryReader(
        github_client,
        owner="jerryjliu",
        repo="gpt_index",
        filter_directories=(["gpt_index", "docs"], GithubRepositoryReader.FilterType.INCLUDE),
        filter_file_extensions=([".py"], GithubRepositoryReader.FilterType.INCLUDE),
        verbose=True,
        concurrent_requests=10,
    )

    docs = loader.load_data(branch="main")

    with open("docs.pkl", "wb") as f:
        pickle.dump(docs, f)

index = GPTSimpleVectorIndex(docs)

index.query("Explain each GPTIndex class?")
```
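### LangChain

The README mentions using the loader as a Tool in a LangChain Agent but stops short of an example. A minimal sketch, assuming the `Tool`/`initialize_agent` API from LangChain at the time of this commit and reusing the `index` built above (tool name and description are illustrative):

```python
from langchain.agents import Tool, initialize_agent
from langchain.llms import OpenAI

tools = [
    Tool(
        name="GPT Index",
        # Wrap the index query as a plain-string function for the agent.
        func=lambda q: str(index.query(q)),
        description="Useful for answering questions about the gpt_index repository.",
    ),
]

agent = initialize_agent(tools, OpenAI(temperature=0), agent="zero-shot-react-description")
agent.run("What does the GithubRepositoryReader loader do?")
```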
6 loader_hub/github_repo/__init__.py Normal file
@@ -0,0 +1,6 @@
"""Init file."""
|
||||
|
||||
from .base import GithubRepositoryReader
|
||||
from .github_client import GithubClient
|
||||
|
||||
__all__ = ["GithubRepositoryReader", "GithubClient"]
|
||||
534 loader_hub/github_repo/base.py Normal file
@@ -0,0 +1,534 @@
"""
|
||||
Github repository reader.
|
||||
|
||||
Retrieves the contents of a Github repository and returns a list of documents.
|
||||
The documents are either the contents of the files in the repository or
|
||||
the text extracted from the files using the parser.
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dataclasses_json import DataClassJsonMixin
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import enum
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from gpt_index.readers.base import BaseReader
|
||||
from gpt_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
|
||||
|
||||
from gpt_index.readers.llamahub_modules.github_repo.github_client import (
|
||||
BaseGithubClient,
|
||||
GitBranchResponseModel,
|
||||
GitCommitResponseModel,
|
||||
GithubClient,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
|
||||
from gpt_index.readers.llamahub_modules.github_repo.utils import (
|
||||
BufferedGitBlobDataIterator,
|
||||
print_if_verbose,
|
||||
get_file_extension,
|
||||
)
|
||||
from gpt_index.readers.schema.base import Document
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GithubRepositoryReader(BaseReader):
    """
    Github repository reader.

    Retrieves the contents of a Github repository and returns a list of documents.
    The documents are either the contents of the files in the repository or the text
    extracted from the files using the parser.

    Examples:
        >>> reader = GithubRepositoryReader(github_client, "owner", "repo")
        >>> branch_documents = reader.load_data(branch="branch")
        >>> commit_documents = reader.load_data(commit_sha="commit_sha")
    """

    class FilterType(enum.Enum):
        """
        Filter type.

        Used to determine whether the filter is inclusive or exclusive.

        Attributes:
            - EXCLUDE: Exclude the files in the directories or with the extensions.
            - INCLUDE: Include only the files in the directories or with the extensions.
        """

        EXCLUDE = enum.auto()
        INCLUDE = enum.auto()
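    # Illustrative examples of the two modes (mirroring the README usage):
    #   filter_directories=(["docs"], FilterType.INCLUDE)  -> keep only blobs under docs/
    #   filter_directories=(["docs"], FilterType.EXCLUDE)  -> keep everything except docs/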
    def __init__(
        self,
        github_client: BaseGithubClient,
        owner: str,
        repo: str,
        use_parser: bool = True,
        verbose: bool = False,
        concurrent_requests: int = 5,
        filter_directories: Optional[Tuple[List[str], FilterType]] = None,
        filter_file_extensions: Optional[Tuple[List[str], FilterType]] = None,
    ):
        """
        Initialize params.

        Args:
            - github_client (BaseGithubClient): Github client.
            - owner (str): Owner of the repository.
            - repo (str): Name of the repository.
            - use_parser (bool): Whether to use the parser to extract
                the text from the files.
            - verbose (bool): Whether to print verbose messages.
            - concurrent_requests (int): Number of concurrent requests to
                make to the Github API.
            - filter_directories (Optional[Tuple[List[str], FilterType]]): Tuple
                containing a list of directories and a FilterType. If the FilterType
                is INCLUDE, only the files in the directories in the list will be
                included. If the FilterType is EXCLUDE, the files in the directories
                in the list will be excluded.
            - filter_file_extensions (Optional[Tuple[List[str], FilterType]]): Tuple
                containing a list of file extensions and a FilterType. If the
                FilterType is INCLUDE, only the files with the extensions in the list
                will be included. If the FilterType is EXCLUDE, the files with the
                extensions in the list will be excluded.
        """
        super().__init__()

        self._owner = owner
        self._repo = repo
        self._use_parser = use_parser
        self._verbose = verbose
        self._concurrent_requests = concurrent_requests
        self._filter_directories = filter_directories
        self._filter_file_extensions = filter_file_extensions

        # Set up the event loop
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # If there is no running loop, create a new one
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

        self._github_client = github_client
    def _check_filter_directories(self, tree_obj_path: str) -> bool:
        """
        Check if a tree object should be allowed based on the directories.

        :param `tree_obj_path`: path of the tree object, e.g. 'gpt_index/readers'

        :return: True if the tree object should be allowed, False otherwise
        """
        filter_directories, filter_type = self._filter_directories
        print_if_verbose(
            self._verbose,
            f"Checking {tree_obj_path} whether to {filter_type} it"
            + f" based on the filter directories: {filter_directories}",
        )

        if filter_type == self.FilterType.EXCLUDE:
            return not any(
                tree_obj_path.startswith(directory)
                or directory.startswith(tree_obj_path)
                for directory in filter_directories
            )
        elif filter_type == self.FilterType.INCLUDE:
            return any(
                tree_obj_path.startswith(directory)
                or directory.startswith(tree_obj_path)
                for directory in filter_directories
            )
        else:
            raise ValueError(
                f"Unknown filter type: {filter_type}. "
                "Please use either FilterType.INCLUDE or FilterType.EXCLUDE."
            )

    def _check_filter_file_extensions(self, tree_obj_path: str) -> bool:
        """
        Check if a tree object should be allowed based on the file extensions.

        :param `tree_obj_path`: path of the tree object, e.g. 'gpt_index/indices'

        :return: True if the tree object should be allowed, False otherwise
        """
        filter_file_extensions, filter_type = self._filter_file_extensions
        print_if_verbose(
            self._verbose,
            f"Checking {tree_obj_path} whether to {filter_type} it"
            + f" based on the filter file extensions: {filter_file_extensions}",
        )

        if filter_type == self.FilterType.EXCLUDE:
            return get_file_extension(tree_obj_path) not in filter_file_extensions
        elif filter_type == self.FilterType.INCLUDE:
            return get_file_extension(tree_obj_path) in filter_file_extensions
        else:
            raise ValueError(
                f"Unknown filter type: {filter_type}. "
                "Please use either FilterType.INCLUDE or FilterType.EXCLUDE."
            )
    def _allow_tree_obj(self, tree_obj_path: str) -> bool:
        """
        Check if a tree object should be allowed.

        :param `tree_obj_path`: path of the tree object

        :return: True if the tree object should be allowed, False otherwise
        """
        # Both filters must pass (the original returned early on the first
        # filter, which skipped the file-extension check entirely).
        if self._filter_directories is not None and not self._check_filter_directories(
            tree_obj_path
        ):
            return False

        if (
            self._filter_file_extensions is not None
            and not self._check_filter_file_extensions(tree_obj_path)
        ):
            return False

        return True
    def _load_data_from_commit(self, commit_sha: str) -> List[Document]:
        """
        Load data from a commit.

        Loads github repository data from a specific commit sha.

        :param `commit_sha`: commit sha

        :return: list of documents
        """
        commit_response: GitCommitResponseModel = self._loop.run_until_complete(
            self._github_client.get_commit(self._owner, self._repo, commit_sha)
        )

        tree_sha = commit_response.commit.tree.sha
        blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))

        print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")

        return self._loop.run_until_complete(
            self._generate_documents(blobs_and_paths=blobs_and_paths)
        )

    def _load_data_from_branch(self, branch: str) -> List[Document]:
        """
        Load data from a branch.

        Loads github repository data from a specific branch.

        :param `branch`: branch name

        :return: list of documents
        """
        branch_data: GitBranchResponseModel = self._loop.run_until_complete(
            self._github_client.get_branch(self._owner, self._repo, branch)
        )

        tree_sha = branch_data.commit.commit.tree.sha
        blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))

        print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")

        return self._loop.run_until_complete(
            self._generate_documents(blobs_and_paths=blobs_and_paths)
        )

    def load_data(
        self,
        commit_sha: Optional[str] = None,
        branch: Optional[str] = None,
    ) -> List[Document]:
        """
        Load data from a commit or a branch.

        Loads github repository data from a specific commit sha or a branch.

        :param `commit_sha`: commit sha
        :param `branch`: branch name

        :return: list of documents
        """
        if commit_sha is not None and branch is not None:
            raise ValueError("You can only specify one of commit_sha or branch.")

        if commit_sha is None and branch is None:
            raise ValueError("You must specify one of commit_sha or branch.")

        if commit_sha is not None:
            return self._load_data_from_commit(commit_sha)

        return self._load_data_from_branch(branch)
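    # Example calls (illustrative): exactly one of the two keyword arguments is given.
    #   reader.load_data(branch="main")
    #   reader.load_data(commit_sha="a6c89159bf8e7086bea2f4305cff3f0a4102e370")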
    async def _recurse_tree(
        self, tree_sha: str, current_path: str = "", current_depth: int = 0
    ) -> Any:
        """
        Recursively get all blob tree objects in a tree.

        And construct their full path relative to the root of the repository.
        (see GitTreeResponseModel.GitTreeObject in
        github_client.py for more information)

        :param `tree_sha`: sha of the tree to recurse
        :param `current_path`: current path of the tree
        :param `current_depth`: current depth of the tree
        :return: list of tuples of
            (tree object, file's full path relative to the root of the repo)
        """
        blobs_and_full_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = []
        print_if_verbose(
            self._verbose, "\t" * current_depth + f"current path: {current_path}"
        )

        tree_data: GitTreeResponseModel = await self._github_client.get_tree(
            self._owner, self._repo, tree_sha
        )
        print_if_verbose(
            self._verbose, "\t" * current_depth + f"processing tree {tree_sha}"
        )
        for tree_obj in tree_data.tree:
            file_path = os.path.join(current_path, tree_obj.path)

            if tree_obj.type == "tree":
                print_if_verbose(
                    self._verbose,
                    "\t" * current_depth + f"recursing into {tree_obj.path}",
                )
                # Guard against an unset filter before unpacking it.
                if (
                    self._filter_directories is not None
                    and not self._check_filter_directories(file_path)
                ):
                    print_if_verbose(
                        self._verbose,
                        "\t" * current_depth
                        + f"ignoring directory {tree_obj.path} due to filter",
                    )
                    continue

                blobs_and_full_paths.extend(
                    await self._recurse_tree(tree_obj.sha, file_path, current_depth + 1)
                )
            elif tree_obj.type == "blob":
                print_if_verbose(
                    self._verbose, "\t" * current_depth + f"found blob {tree_obj.path}"
                )
                if (
                    self._filter_file_extensions is not None
                    and not self._check_filter_file_extensions(file_path)
                ):
                    print_if_verbose(
                        self._verbose,
                        "\t" * current_depth
                        + f"ignoring file {tree_obj.path} due to filter",
                    )
                    continue

                blobs_and_full_paths.append((tree_obj, file_path))
        return blobs_and_full_paths
    async def _generate_documents(
        self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]]
    ) -> List[Document]:
        """
        Generate documents from a list of blobs and their full paths.

        :param `blobs_and_paths`: list of tuples of
            (tree object, file's full path in the repo relative to the root of the repo)
        :return: list of documents
        """
        buffered_iterator = BufferedGitBlobDataIterator(
            blobs_and_paths=blobs_and_paths,
            github_client=self._github_client,
            owner=self._owner,
            repo=self._repo,
            loop=self._loop,
            buffer_size=self._concurrent_requests,  # TODO: make this configurable
            verbose=self._verbose,
        )

        documents = []
        async for blob_data, full_path in buffered_iterator:
            print_if_verbose(self._verbose, f"generating document for {full_path}")
            assert (
                blob_data.encoding == "base64"
            ), f"blob encoding {blob_data.encoding} not supported"
            decoded_bytes = None
            try:
                decoded_bytes = base64.b64decode(blob_data.content)
                del blob_data.content
            except binascii.Error:
                print_if_verbose(
                    self._verbose, f"could not decode {full_path} as base64"
                )
                continue

            if self._use_parser:
                document = self._parse_supported_file(
                    file_path=full_path,
                    file_content=decoded_bytes,
                    tree_sha=blob_data.sha,
                    tree_path=full_path,
                )
                if document is not None:
                    documents.append(document)
                    continue
                print_if_verbose(
                    self._verbose,
                    f"could not parse {full_path} as a supported file type"
                    + " - falling back to decoding as utf-8 raw text",
                )

            try:
                if decoded_bytes is None:
                    raise ValueError("decoded_bytes is None")
                decoded_text = decoded_bytes.decode("utf-8")
            except UnicodeDecodeError:
                print_if_verbose(
                    self._verbose, f"could not decode {full_path} as utf-8"
                )
                continue
            print_if_verbose(
                self._verbose,
                f"got {len(decoded_text)} characters"
                + f" - adding to documents - {full_path}",
            )
            document = Document(
                text=decoded_text,
                doc_id=blob_data.sha,
                extra_info={
                    "file_path": full_path,
                    "file_name": full_path.split("/")[-1],
                },
            )
            documents.append(document)
        return documents
    def _parse_supported_file(
        self, file_path: str, file_content: bytes, tree_sha: str, tree_path: str
    ) -> Optional[Document]:
        """
        Parse a file if it is supported by a parser.

        :param `file_path`: path of the file in the repo
        :param `file_content`: content of the file
        :return: Document if the file is supported by a parser, None otherwise
        """
        file_extension = get_file_extension(file_path)
        if (parser := DEFAULT_FILE_EXTRACTOR.get(file_extension)) is not None:
            parser.init_parser()
            print_if_verbose(
                self._verbose,
                f"parsing {file_path} "
                + f"as {file_extension} with "
                + f"{parser.__class__.__name__}",
            )
            with tempfile.TemporaryDirectory() as tmpdirname:
                with tempfile.NamedTemporaryFile(
                    dir=tmpdirname,
                    suffix=file_extension,  # get_file_extension already includes the dot
                    mode="w+b",
                    delete=False,
                ) as tmpfile:
                    print_if_verbose(
                        self._verbose,
                        "created a temporary file "
                        + f"{tmpfile.name} for parsing {file_path}",
                    )
                    tmpfile.write(file_content)
                    tmpfile.flush()
                    tmpfile.close()
                    try:
                        parsed_file = parser.parse_file(pathlib.Path(tmpfile.name))
                        parsed_file = "\n\n".join(parsed_file)
                    except Exception as e:
                        print_if_verbose(
                            self._verbose, f"error while parsing {file_path}"
                        )
                        logger.error(
                            "Error while parsing "
                            + f"{file_path} with "
                            + f"{parser.__class__.__name__}:\n{e}"
                        )
                        parsed_file = None
                    finally:
                        os.remove(tmpfile.name)
                    if parsed_file is None:
                        return None
                    return Document(
                        text=parsed_file,
                        doc_id=tree_sha,
                        extra_info={
                            "file_path": file_path,
                            "file_name": tree_path,
                        },
                    )
        return None
if __name__ == "__main__":
    import time

    def timeit(func: Callable) -> Callable:
        """Time a function."""

        def wrapper(*args: Any, **kwargs: Any) -> None:
            """Calculate time taken to run a function."""
            start = time.time()
            func(*args, **kwargs)
            end = time.time()
            print(f"Time taken: {end - start} seconds for {func.__name__}")

        return wrapper

    github_client = GithubClient(github_token=os.environ["GITHUB_TOKEN"], verbose=True)

    reader1 = GithubRepositoryReader(
        github_client=github_client,
        owner="jerryjliu",
        repo="gpt_index",
        use_parser=False,
        verbose=True,
        filter_directories=(
            ["docs"],
            GithubRepositoryReader.FilterType.INCLUDE,
        ),
        filter_file_extensions=(
            [".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".json", ".ipynb"],
            GithubRepositoryReader.FilterType.EXCLUDE,
        ),
    )

    @timeit
    def load_data_from_commit() -> None:
        """Load data from a commit."""
        documents = reader1.load_data(
            commit_sha="22e198b3b166b5facd2843d6a62ac0db07894a13"
        )
        for document in documents:
            print(document.extra_info)

    @timeit
    def load_data_from_branch() -> None:
        """Load data from a branch."""
        documents = reader1.load_data(branch="main")
        for document in documents:
            print(document.extra_info)

    input("Press enter to load github repository from branch name...")

    load_data_from_branch()

    # input("Press enter to load github repository from commit sha...")

    # load_data_from_commit()
435 loader_hub/github_repo/github_client.py Normal file
@@ -0,0 +1,435 @@
"""
|
||||
Github API client for the GPT-Index library.
|
||||
|
||||
This module contains the Github API client for the GPT-Index library.
|
||||
It is used by the Github readers to retrieve the data from Github.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dataclasses_json import DataClassJsonMixin
|
||||
|
||||
|
||||
@dataclass
class GitTreeResponseModel(DataClassJsonMixin):
    """
    Dataclass for the response from the Github API's getTree endpoint.

    Attributes:
        - sha (str): SHA1 checksum ID of the tree.
        - url (str): URL for the tree.
        - tree (List[GitTreeObject]): List of objects in the tree.
        - truncated (bool): Whether the tree is truncated.

    Examples:
        >>> tree = await client.get_tree("owner", "repo", "branch")
        >>> tree.sha
    """

    @dataclass
    class GitTreeObject(DataClassJsonMixin):
        """
        Dataclass for the objects in the tree.

        Attributes:
            - path (str): Path to the object.
            - mode (str): Mode of the object.
            - type (str): Type of the object.
            - sha (str): SHA1 checksum ID of the object.
            - url (str): URL for the object.
            - size (Optional[int]): Size of the object (only for blobs).
        """

        path: str
        mode: str
        type: str
        sha: str
        url: str
        size: Optional[int] = None

    sha: str
    url: str
    tree: List[GitTreeObject]
    truncated: bool
@dataclass
class GitBlobResponseModel(DataClassJsonMixin):
    """
    Dataclass for the response from the Github API's getBlob endpoint.

    Attributes:
        - content (str): Content of the blob.
        - encoding (str): Encoding of the blob.
        - url (str): URL for the blob.
        - sha (str): SHA1 checksum ID of the blob.
        - size (int): Size of the blob.
        - node_id (str): Node ID of the blob.
    """

    content: str
    encoding: str
    url: str
    sha: str
    size: int
    node_id: str
@dataclass
class GitCommitResponseModel(DataClassJsonMixin):
    """
    Dataclass for the response from the Github API's getCommit endpoint.

    Attributes:
        - commit (Commit): Commit object for the commit (commit.commit.tree.sha).
    """

    @dataclass
    class Commit(DataClassJsonMixin):
        """Dataclass for the commit object in the commit. (commit.commit)."""

        @dataclass
        class Tree(DataClassJsonMixin):
            """
            Dataclass for the tree object in the commit.

            Attributes:
                - sha (str): SHA for the commit
            """

            sha: str

        tree: Tree

    commit: Commit
@dataclass
class GitBranchResponseModel(DataClassJsonMixin):
    """
    Dataclass for the response from the Github API's getBranch endpoint.

    Attributes:
        - commit (Commit): Commit object for the branch.
    """

    @dataclass
    class Commit(DataClassJsonMixin):
        """Dataclass for the commit object in the branch. (commit.commit)."""

        @dataclass
        class Commit(DataClassJsonMixin):
            """Dataclass for the commit object in the commit. (commit.commit.tree)."""

            @dataclass
            class Tree(DataClassJsonMixin):
                """
                Dataclass for the tree object in the commit.

                Usage: commit.commit.tree.sha
                """

                sha: str

            tree: Tree

        commit: Commit

    commit: Commit


class BaseGithubClient(Protocol):
    """Protocol describing the Github client interface the reader relies on."""

    def get_all_endpoints(self) -> Dict[str, str]:
        ...

    async def request(
        self,
        endpoint: str,
        method: str,
        headers: Dict[str, Any] = {},
        **kwargs: Any,
    ) -> Any:
        ...

    async def get_tree(
        self,
        owner: str,
        repo: str,
        tree_sha: str,
    ) -> GitTreeResponseModel:
        ...

    async def get_blob(
        self,
        owner: str,
        repo: str,
        file_sha: str,
    ) -> GitBlobResponseModel:
        ...

    async def get_commit(
        self,
        owner: str,
        repo: str,
        commit_sha: str,
    ) -> GitCommitResponseModel:
        ...

    async def get_branch(
        self,
        owner: str,
        repo: str,
        branch_name: str,
    ) -> GitBranchResponseModel:
        ...
class GithubClient:
    """
    An asynchronous client for interacting with the Github API.

    This client is used for making API requests to Github.
    It provides methods for accessing the Github API endpoints.
    The client requires a Github token for authentication,
    which can be passed as an argument or set as an environment variable.
    If no Github token is provided, the client will raise a ValueError.

    Examples:
        >>> client = GithubClient("my_github_token")
        >>> branch_info = await client.get_branch("owner", "repo", "branch")
    """

    DEFAULT_BASE_URL = "https://api.github.com"
    DEFAULT_API_VERSION = "2022-11-28"

    def __init__(
        self,
        github_token: Optional[str] = None,
        base_url: str = DEFAULT_BASE_URL,
        api_version: str = DEFAULT_API_VERSION,
        verbose: bool = False,
    ) -> None:
        """
        Initialize the GithubClient.

        Args:
            - github_token (str): Github token for authentication.
                If not provided, the client will try to get it from
                the GITHUB_TOKEN environment variable.
            - base_url (str): Base URL for the Github API
                (defaults to "https://api.github.com").
            - api_version (str): Github API version (defaults to "2022-11-28").

        Raises:
            ValueError: If no Github token is provided.
        """
        if github_token is None:
            github_token = os.getenv("GITHUB_TOKEN")
        if github_token is None:
            raise ValueError(
                "Please provide a Github token. "
                + "You can do so by passing it as an argument to the GithubClient, "
                + "or by setting the GITHUB_TOKEN environment variable."
            )

        self._base_url = base_url
        self._api_version = api_version
        self._verbose = verbose

        self._endpoints = {
            "getTree": "/repos/{owner}/{repo}/git/trees/{tree_sha}",
            "getBranch": "/repos/{owner}/{repo}/branches/{branch}",
            "getBlob": "/repos/{owner}/{repo}/git/blobs/{file_sha}",
            "getCommit": "/repos/{owner}/{repo}/commits/{commit_sha}",
        }

        self._headers = {
            "Accept": "application/vnd.github+json",
            "Authorization": f"Bearer {github_token}",
            "X-GitHub-Api-Version": f"{self._api_version}",
        }

    def get_all_endpoints(self) -> Dict[str, str]:
        """Get all available endpoints."""
        return {**self._endpoints}
    async def request(
        self,
        endpoint: str,
        method: str,
        headers: Dict[str, Any] = {},
        **kwargs: Any,
    ) -> Any:
        """
        Make an API request to the Github API.

        This method is used for making API requests to the Github API.
        It is used internally by the other methods in the client.

        Args:
            - `endpoint (str)`: Name of the endpoint to make the request to.
            - `method (str)`: HTTP method to use for the request.
            - `headers (dict)`: HTTP headers to include in the request.
            - `**kwargs`: Keyword arguments to pass to the endpoint URL.

        Returns:
            - `response (httpx.Response)`: Response from the API request.

        Raises:
            - ImportError: If the `httpx` library is not installed.
            - httpx.HTTPError: If the API request fails.

        Examples:
            >>> response = await client.request("getTree", "GET",
            ...                                 owner="owner", repo="repo",
            ...                                 tree_sha="tree_sha")
        """
        try:
            import httpx
        except ImportError:
            raise ImportError(
                "Please install httpx to use the GithubRepositoryReader. "
                "You can do so by running `pip install httpx`."
            )

        _headers = {**self._headers, **headers}

        _client: httpx.AsyncClient
        async with httpx.AsyncClient(
            headers=_headers, base_url=self._base_url
        ) as _client:
            try:
                response = await _client.request(
                    method, url=self._endpoints[endpoint].format(**kwargs)
                )
            except httpx.HTTPError as excp:
                print(f"HTTP Exception for {excp.request.url} - {excp}")
                raise excp
        return response
    async def get_branch(
        self, owner: str, repo: str, branch: str
    ) -> GitBranchResponseModel:
        """
        Get information about a branch. (Github API endpoint: getBranch).

        Args:
            - `owner (str)`: Owner of the repository.
            - `repo (str)`: Name of the repository.
            - `branch (str)`: Name of the branch.

        Returns:
            - `branch_info (GitBranchResponseModel)`: Information about the branch.

        Examples:
            >>> branch_info = await client.get_branch("owner", "repo", "branch")
        """
        return GitBranchResponseModel.from_json(
            (
                await self.request(
                    "getBranch", "GET", owner=owner, repo=repo, branch=branch
                )
            ).text
        )

    async def get_tree(
        self, owner: str, repo: str, tree_sha: str
    ) -> GitTreeResponseModel:
        """
        Get information about a tree. (Github API endpoint: getTree).

        Args:
            - `owner (str)`: Owner of the repository.
            - `repo (str)`: Name of the repository.
            - `tree_sha (str)`: SHA of the tree.

        Returns:
            - `tree_info (GitTreeResponseModel)`: Information about the tree.

        Examples:
            >>> tree_info = await client.get_tree("owner", "repo", "tree_sha")
        """
        return GitTreeResponseModel.from_json(
            (
                await self.request(
                    "getTree", "GET", owner=owner, repo=repo, tree_sha=tree_sha
                )
            ).text
        )

    async def get_blob(
        self, owner: str, repo: str, file_sha: str
    ) -> GitBlobResponseModel:
        """
        Get information about a blob. (Github API endpoint: getBlob).

        Args:
            - `owner (str)`: Owner of the repository.
            - `repo (str)`: Name of the repository.
            - `file_sha (str)`: SHA of the file.

        Returns:
            - `blob_info (GitBlobResponseModel)`: Information about the blob.

        Examples:
            >>> blob_info = await client.get_blob("owner", "repo", "file_sha")
        """
        return GitBlobResponseModel.from_json(
            (
                await self.request(
                    "getBlob", "GET", owner=owner, repo=repo, file_sha=file_sha
                )
            ).text
        )

    async def get_commit(
        self, owner: str, repo: str, commit_sha: str
    ) -> GitCommitResponseModel:
        """
        Get information about a commit. (Github API endpoint: getCommit).

        Args:
            - `owner (str)`: Owner of the repository.
            - `repo (str)`: Name of the repository.
            - `commit_sha (str)`: SHA of the commit.

        Returns:
            - `commit_info (GitCommitResponseModel)`: Information about the commit.

        Examples:
            >>> commit_info = await client.get_commit("owner", "repo", "commit_sha")
        """
        return GitCommitResponseModel.from_json(
            (
                await self.request(
                    "getCommit", "GET", owner=owner, repo=repo, commit_sha=commit_sha
                )
            ).text
        )
if __name__ == "__main__":
    import asyncio

    async def main() -> None:
        """Test the GithubClient."""
        client = GithubClient()
        response = await client.get_tree(
            owner="ahmetkca", repo="CommitAI", tree_sha="with-body"
        )

        for obj in response.tree:
            if obj.type == "blob":
                print(obj.path)
                print(obj.sha)
                blob_response = await client.get_blob(
                    owner="ahmetkca", repo="CommitAI", file_sha=obj.sha
                )
                print(blob_response.content)

    asyncio.run(main())
1 loader_hub/github_repo/requirements.txt Normal file
@@ -0,0 +1 @@
httpx
170 loader_hub/github_repo/utils.py Normal file
@@ -0,0 +1,170 @@
"""
|
||||
Github readers utils.
|
||||
|
||||
This module contains utility functions for the Github readers.
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
from gpt_index.readers.github_readers.github_api_client import (
|
||||
GitBlobResponseModel,
|
||||
GithubClient,
|
||||
GitTreeResponseModel,
|
||||
)
|
||||
|
||||
|
||||
def print_if_verbose(verbose: bool, message: str) -> None:
|
||||
"""Log message if verbose is True."""
|
||||
if verbose:
|
||||
print(message)
|
||||
|
||||
|
||||
def get_file_extension(filename: str) -> str:
|
||||
"""Get file extension."""
|
||||
return f".{os.path.splitext(filename)[1][1:].lower()}"
|
||||
|
||||
|
||||
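# For example, get_file_extension("loader_hub/github_repo/base.py") returns
# ".py", matching the keys used by DEFAULT_FILE_EXTRACTOR and the
# filter_file_extensions entries, while get_file_extension("Makefile") returns ".".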
class BufferedAsyncIterator(ABC):
    """
    Base class for buffered async iterators.

    This class is to be used as a base class for async iterators
    that need to buffer the results of an async operation.
    The async operation is defined in the _fill_buffer method.
    The _fill_buffer method is called when the buffer is empty.
    """

    def __init__(self, buffer_size: int):
        """
        Initialize params.

        Args:
            - `buffer_size (int)`: Size of the buffer.
                It is also the number of items that will
                be retrieved from the async operation at once
                (see _fill_buffer). Setting it to 1
                will result in the same behavior as a synchronous iterator.
        """
        self._buffer_size = buffer_size
        self._buffer: List[Tuple[GitBlobResponseModel, str]] = []
        self._index = 0

    @abstractmethod
    async def _fill_buffer(self) -> None:
        raise NotImplementedError

    def __aiter__(self) -> "BufferedAsyncIterator":
        """Return the iterator object."""
        return self

    async def __anext__(self) -> Tuple[GitBlobResponseModel, str]:
        """
        Get next item.

        Returns:
            - `item (Tuple[GitBlobResponseModel, str])`: Next item.

        Raises:
            - `StopAsyncIteration`: If there are no more items.
        """
        if not self._buffer:
            await self._fill_buffer()

        if not self._buffer:
            raise StopAsyncIteration

        item = self._buffer.pop(0)
        self._index += 1
        return item
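# Sketch of the subclass contract (illustrative): _fill_buffer should fetch up
# to self._buffer_size items starting at self._index and append them to
# self._buffer; leaving the buffer empty signals the end of iteration.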
class BufferedGitBlobDataIterator(BufferedAsyncIterator):
    """
    Buffered async iterator for Git blobs.

    This class is an async iterator that buffers the results of the get_blob operation.
    It is used to retrieve the contents of the files in a Github repository.
    The getBlob endpoint supports up to 100 megabytes of content for blobs.
    This concrete implementation of BufferedAsyncIterator allows you to lazily retrieve
    the contents of the files in a Github repository.
    Otherwise you would have to retrieve all the contents of
    the files in the repository at once, which would
    be problematic if the repository is large.
    """

    def __init__(
        self,
        blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]],
        github_client: GithubClient,
        owner: str,
        repo: str,
        loop: asyncio.AbstractEventLoop,
        buffer_size: int,
        verbose: bool = False,
    ):
        """
        Initialize params.

        Args:
            - blobs_and_paths (List[Tuple[GitTreeResponseModel.GitTreeObject, str]]):
                List of tuples containing the blob and the path of the file.
            - github_client (GithubClient): Github client.
            - owner (str): Owner of the repository.
            - repo (str): Name of the repository.
            - loop (asyncio.AbstractEventLoop): Event loop.
            - buffer_size (int): Size of the buffer.
        """
        super().__init__(buffer_size)
        self._blobs_and_paths = blobs_and_paths
        self._github_client = github_client
        self._owner = owner
        self._repo = repo
        self._verbose = verbose
        if loop is None:
            loop = asyncio.get_event_loop()

    async def _fill_buffer(self) -> None:
        """
        Fill the buffer with the results of the get_blob operation.

        The get_blob operation is called for each blob in the blobs_and_paths list.
        The blobs are retrieved in batches of size buffer_size.
        """
        self._buffer = []
        start = self._index
        end = min(start + self._buffer_size, len(self._blobs_and_paths))

        if start >= end:
            return

        if self._verbose:
            start_t = time.time()
        results: List[GitBlobResponseModel] = await asyncio.gather(
            *[
                self._github_client.get_blob(self._owner, self._repo, blob.sha)
                for blob, _ in self._blobs_and_paths[
                    start:end
                ]  # TODO: use batch_size instead of buffer_size for concurrent requests
            ]
        )
        if self._verbose:
            end_t = time.time()
            blob_names_and_sizes = [
                (blob.path, blob.size) for blob, _ in self._blobs_and_paths[start:end]
            ]
            print(
                "Time to get blobs ("
                + f"{blob_names_and_sizes}"
                + f"): {end_t - start_t:.2f} seconds"
            )

        self._buffer = [
            (result, path)
            for result, (_, path) in zip(results, self._blobs_and_paths[start:end])
        ]
loader_hub/library.json
@@ -26,7 +26,11 @@
     "CJKPDFReader": {
         "id": "file/cjk_pdf",
         "author": "JiroShimaya",
-        "keywords": ["Japanese", "Chinese", "Korean"]
+        "keywords": [
+            "Japanese",
+            "Chinese",
+            "Korean"
+        ]
     },
     "DocxReader": {
         "id": "file/docx",
@@ -39,7 +43,10 @@
     "ImageReader": {
         "id": "file/image",
         "author": "ravi03071991",
-        "keywords": ["invoice", "receipt"]
+        "keywords": [
+            "invoice",
+            "receipt"
+        ]
     },
     "EpubReader": {
         "id": "file/epub",
@@ -68,17 +75,30 @@
     "BeautifulSoupWebReader": {
         "id": "web/beautiful_soup_web",
         "author": "thejessezhang",
-        "keywords": ["substack", "readthedocs", "documentation"]
+        "keywords": [
+            "substack",
+            "readthedocs",
+            "documentation"
+        ]
     },
     "RssReader": {
         "id": "web/rss",
         "author": "bborn",
-        "keywords": ["feed", "rss", "atom"]
+        "keywords": [
+            "feed",
+            "rss",
+            "atom"
+        ]
     },
     "DatabaseReader": {
         "id": "database",
         "author": "kevinqz",
-        "keywords": ["sql", "postgres", "snowflake", "aws rds"]
+        "keywords": [
+            "sql",
+            "postgres",
+            "snowflake",
+            "aws rds"
+        ]
     },
     "DiscordReader": {
         "id": "discord",
@@ -154,22 +174,39 @@
     "UnstructuredReader": {
         "id": "file/unstructured",
         "author": "thejessezhang",
-        "keywords": ["sec", "html", "eml", "10k", "10q", "unstructured.io"]
+        "keywords": [
+            "sec",
+            "html",
+            "eml",
+            "10k",
+            "10q",
+            "unstructured.io"
+        ]
     },
     "KnowledgeBaseWebReader": {
         "id": "web/knowledge_base",
         "author": "jasonwcfan",
-        "keywords": ["documentation"]
+        "keywords": [
+            "documentation"
+        ]
     },
     "S3Reader": {
         "id": "s3",
         "author": "thejessezhang",
-        "keywords": ["aws s3", "bucket", "amazon web services"]
+        "keywords": [
+            "aws s3",
+            "bucket",
+            "amazon web services"
+        ]
     },
     "RemoteReader": {
         "id": "remote",
         "author": "thejessezhang",
-        "keywords": ["hosted", "url", "gutenberg"]
+        "keywords": [
+            "hosted",
+            "url",
+            "gutenberg"
+        ]
     },
     "RemoteDepthReader": {
         "id": "remote_depth",
@@ -183,12 +220,18 @@
     "DadJokesReader": {
         "id": "dad_jokes",
         "author": "sidu",
-        "keywords": ["jokes", "dad jokes"]
+        "keywords": [
+            "jokes",
+            "dad jokes"
+        ]
     },
     "WhatsappChatLoader": {
         "id": "whatsapp",
         "author": "batmanscode",
-        "keywords": ["whatsapp", "chat"]
+        "keywords": [
+            "whatsapp",
+            "chat"
+        ]
     },
     "BilibiliTranscriptReader": {
         "id": "bilibili",
@@ -197,16 +240,45 @@
     "RedditReader": {
         "id": "reddit",
         "author": "vanessahlyan",
-        "keywords": ["reddit", "subreddit", "search", "comments"]
+        "keywords": [
+            "reddit",
+            "subreddit",
+            "search",
+            "comments"
+        ]
     },
     "MemosReader": {
         "id": "memos",
         "author": "bubu",
-        "keywords": ["memos", "note"]
+        "keywords": [
+            "memos",
+            "note"
+        ]
     },
     "SpotifyReader": {
         "id": "spotify",
         "author": "ong",
-        "keywords": ["spotify", "music"]
+        "keywords": [
+            "spotify",
+            "music"
+        ]
+    },
+    "GithubRepositoryReader": {
+        "id": "github_repo",
+        "author": "ahmetkca",
+        "keywords": [
+            "github",
+            "repository",
+            "git",
+            "code",
+            "source code",
+            "placeholder"
+        ],
+        "extra_files": [
+            "github_client.py",
+            "utils.py",
+            "__init__.py"
+        ]
     },
     "RDFReader": {
@@ -214,4 +286,4 @@
         "author": "mommi84",
         "keywords": ["rdf", "n-triples", "graph", "knowledge graph"]
     }
 }
104 tests/test_github_reader.py Normal file
@@ -0,0 +1,104 @@
import base64
import os
from typing import List, Tuple

import pytest

# Skip by default due to network requests.
# Remove this to test changes to GithubRepositoryReader.
pytest.skip(allow_module_level=True)

from loader_hub.github_repo import GithubClient, GithubRepositoryReader


@pytest.fixture
def github_client():
    return GithubClient(
        github_token=os.getenv("GITHUB_API_TOKEN"),
        verbose=True,
    )


@pytest.mark.asyncio
async def test_github_client(github_client):
    owner = "emptycrown"
    repo = "llama-hub"
    branch = "main"
    # Points to "Add spotify reader",
    # https://github.com/emptycrown/llama-hub/commit/0cd691322e5244b48b68e3588d1343eb53f3a112
    commit_sha = "0cd691322e5244b48b68e3588d1343eb53f3a112"

    # test get_branch
    branch_data = await github_client.get_branch(owner, repo, branch)
    assert branch_data.name == branch
    assert (
        branch_data._links.self
        == f"https://api.github.com/repos/{owner}/{repo}/branches/{branch}"
    ), "Branch self link is incorrect"
    assert (
        branch_data._links.html == f"https://github.com/{owner}/{repo}/tree/{branch}"
    ), "Branch html link is incorrect"

    # test get_commit
    commit_data = await github_client.get_commit(owner, repo, commit_sha)
    assert commit_data.sha == commit_sha, "Commit sha is incorrect"
    assert (
        commit_data.url
        == f"https://api.github.com/repos/{owner}/{repo}/commits/{commit_sha}"
    ), "Commit url is incorrect"

    # test get_tree
    tree_data = await github_client.get_tree(owner, repo, commit_data.commit.tree.sha)
    assert (
        tree_data.url
        == f"https://api.github.com/repos/{owner}/{repo}/git/trees/{commit_data.commit.tree.sha}"
    ), "Tree url is incorrect"
    assert tree_data.sha == commit_data.commit.tree.sha, "Tree sha is incorrect"

    expected_files_in_first_depth_of_the_tree: List[Tuple[str, str]] = [
        ("test_requirements.txt", "blob"),
        ("README.md", "blob"),
        ("Makefile", "blob"),
        (".gitignore", "blob"),
        ("tests", "tree"),
        ("loader_hub", "tree"),
        (".github", "tree"),
    ]
    # Check that the first depth of the tree has exactly the expected files:
    # all expected files should be in the tree, and vice versa.
    assert len(tree_data.tree) == len(
        expected_files_in_first_depth_of_the_tree
    ), "The number of files in the first depth of the tree is incorrect"
    for file in expected_files_in_first_depth_of_the_tree:
        assert file in [
            (tree_file.path, tree_file.type) for tree_file in tree_data.tree
        ], f"{file} is not in the first depth of the tree"
    # checking the opposite
    for tree_obj in tree_data.tree:
        assert (
            tree_obj.path,
            tree_obj.type,
        ) in expected_files_in_first_depth_of_the_tree, (
            f"{tree_obj.path} is not in the expected files"
        )

    # find test_requirements.txt in the tree
    test_requirements_txt = [
        tree_obj
        for tree_obj in tree_data.tree
        if tree_obj.path == "test_requirements.txt"
    ][0]

    # test get_blob
    blob_data = await github_client.get_blob(owner, repo, test_requirements_txt.sha)
    assert blob_data.encoding == "base64", "Blob encoding is incorrect"
    assert (
        blob_data.url
        == f"https://api.github.com/repos/{owner}/{repo}/git/blobs/{test_requirements_txt.sha}"
    ), "Blob url is incorrect"
    assert blob_data.sha == test_requirements_txt.sha, "Blob sha is incorrect"

    # decode the base64 blob content to utf-8
    decoded_blob_content = base64.b64decode(blob_data.content).decode("utf-8")

    expected_decoded_blob_content = """

# For testing
pytest==7.2.1
pytest-dotenv==0.5.2
# TODO: remove gpt_index after migration
https://github.com/jerryjliu/gpt_index/archive/master.zip

llama-index

# For linting
# linting stubs
types-requests==2.28.11.8
# formatting
black==22.12.0
isort==5.11.4
"""
    # check that the decoded blob content matches, ignoring blank lines
    for actual, expected in zip(
        filter(lambda x: x != "", decoded_blob_content.splitlines()),
        filter(lambda x: x != "", expected_decoded_blob_content.splitlines()),
    ):
        assert actual == expected, f"{actual} is not equal to {expected}"