Add GitHub Repository Reader (#34)

* add GitHub repository reader; test a new way to download the loader

* test imports when downloaded from gpt_index

* Refactor(Github Repo): Move github_client and utils to modules

* Moved github_client.py and utils.py from loader_hub/github_repo to modules/github_repo
* Updated import statements in base.py to reflect the new location

* temp

* Refactor(GithubRepositoryReader): Add github_client argument

- Add github_client argument to GithubRepositoryReader constructor
- Set default value for github_client argument
- Update docstring to reflect changes

* Refactor(Github Repo): Update init file

- Remove imports of base, github_client and utils
- Add imports of GithubRepositoryReader and GithubClient
- Update __all__ to include the new imports

* Fix(library): Update library.json

- Updated library.json to include __init__.py file

* Refactor(GithubRepositoryReader): Add filter for directories and files

- Add filter for directories and files in GithubRepositoryReader
- Ignore directories and files that do not pass the filter
- Print out if directory or file is ignored due to filter

* Refactor(BaseReader): Check filter files

- Refactor `_check_filter_files` to `_check_filter_file_extensions` in `BaseReader`
- Ignore files that do not pass the filter

* Docs(FilterType): Add documentation for FilterType enum

- Add documentation for FilterType enum
- Explain what the enum is used for
- Describe the attributes of the enum

* Add(GPT Index): Add GPT Index example

Add GPT Index example to README
- Set OPENAI_API_KEY environment variable
- Download GithubRepositoryReader module
- Create GithubClient and GithubRepositoryReader
- Load data from Github Repository
- Create GPTSimpleVectorIndex
- Query the index

* change the import path for extras

* change import path for extra files to absolute

* Add test for GithubClient; currently not using mocks, which is not ideal

* Update test_github_reader.py

* Update test_github_reader.py

---------

Co-authored-by: Jesse Zhang <jessetanzhang@gmail.com>
ahmetkca 2023-02-25 02:41:48 -05:00 committed by GitHub
parent 457e7888e9
commit 5a27264db1
9 changed files with 1420 additions and 15 deletions

View File

@@ -1 +1 @@
"""Init file."""
"""Init file."""

View File: README.md

@@ -0,0 +1,83 @@
# Github Repository Loader
This loader takes in `owner`, `repo`, `branch`, `commit` and other optional parameters, such as filters for directories or for only allowing files with certain extensions. It then fetches all the contents of the GitHub repository.
As a prerequisite, you will need to generate a personal access token. See [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token) for instructions.
## Usage
To use this loader, pass in the `owner` and `repo`, plus either a `branch` or a `commit_sha`. For example: `owner = jerryjliu`, `repo = gpt_index`, and either `branch = main` or `commit_sha = a6c89159bf8e7086bea2f4305cff3f0a4102e370`.
```python
import os
from gpt_index import download_loader
download_loader("GithubRepositoryReader")
from modules.github_repo import GithubRepositoryReader, GithubClient
github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
loader = GithubRepositoryReader(
github_client,
owner = "jerryjliu",
repo = "gpt_index",
filter_directories = (["gpt_index", "docs"], GithubRepositoryReader.FilterType.INCLUDE),
filter_file_extensions = ([".py"], GithubRepositoryReader.FilterType.INCLUDE),
verbose = True,
concurrent_requests = 10,
)
docs_branch = loader.load_data(branch="main")
docs_commit = loader.load_data(commit_sha="a6c89159bf8e7086bea2f4305cff3f0a4102e370")

for doc in docs_branch:
    print(doc.extra_info)
```
## Examples
This loader is designed to be used as a way to load data into [GPT Index](https://github.com/jerryjliu/gpt_index/tree/main/gpt_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent; both uses are sketched below.
### GPT Index
```python
import pickle
import os
assert (
os.getenv("OPENAI_API_KEY") is not None
), "Please set the OPENAI_API_KEY environment variable."
from gpt_index import GPTSimpleVectorIndex, download_loader
download_loader("GithubRepositoryReader")
from modules.github_repo import GithubClient, GithubRepositoryReader
docs = None
if os.path.exists("docs.pkl"):
with open("docs.pkl", "rb") as f:
docs = pickle.load(f)
if docs is None:
github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
loader = GithubRepositoryReader(
github_client,
owner = "jerryjliu",
repo = "gpt_index",
filter_directories = (["gpt_index", "docs"], GithubRepositoryReader.FilterType.INCLUDE),
filter_file_extensions = ([".py"], GithubRepositoryReader.FilterType.INCLUDE),
verbose = True,
concurrent_requests = 10,
)
docs = loader.load_data(branch="main")
with open("docs.pkl", "wb") as f:
pickle.dump(docs, f)
index = GPTSimpleVectorIndex(docs)
index.query("Explain each GPTIndex class?")
```
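
### LangChain

The same index can also back an Agent tool. The snippet below is a minimal, illustrative sketch, assuming the `docs` loaded above and LangChain's `Tool`/`initialize_agent` API; the tool name, description, and prompt are placeholders to adapt to your setup.

```python
from gpt_index import GPTSimpleVectorIndex
from langchain.agents import Tool, initialize_agent
from langchain.llms import OpenAI

index = GPTSimpleVectorIndex(docs)

tools = [
    Tool(
        name="GPT Index",
        # GPT Index responses are not plain strings, so cast for the agent.
        func=lambda q: str(index.query(q)),
        description="Useful for answering questions about the loaded repository.",
    ),
]

agent = initialize_agent(
    tools, OpenAI(temperature=0), agent="zero-shot-react-description", verbose=True
)
agent.run("What does the GithubRepositoryReader class do?")
```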

View File: __init__.py

@@ -0,0 +1,6 @@
"""Init file."""
from .base import GithubRepositoryReader
from .github_client import GithubClient
__all__ = ["GithubRepositoryReader", "GithubClient"]

View File: base.py

@@ -0,0 +1,534 @@
"""
Github repository reader.
Retrieves the contents of a Github repository and returns a list of documents.
The documents are either the contents of the files in the repository or
the text extracted from the files using the parser.
"""
import asyncio
import base64
import binascii
import enum
import logging
import os
import pathlib
import tempfile
from typing import Any, Callable, List, Optional, Tuple
from gpt_index.readers.base import BaseReader
from gpt_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from gpt_index.readers.llamahub_modules.github_repo.github_client import (
BaseGithubClient,
GitBranchResponseModel,
GitCommitResponseModel,
GithubClient,
GitTreeResponseModel,
)
from gpt_index.readers.llamahub_modules.github_repo.utils import (
BufferedGitBlobDataIterator,
print_if_verbose,
get_file_extension,
)
from gpt_index.readers.schema.base import Document
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GithubRepositoryReader(BaseReader):
"""
Github repository reader.
Retrieves the contents of a Github repository and returns a list of documents.
The documents are either the contents of the files in the repository or the text
extracted from the files using the parser.
Examples:
        >>> github_client = GithubClient(os.getenv("GITHUB_TOKEN"))
        >>> reader = GithubRepositoryReader(github_client, owner="owner", repo="repo")
        >>> branch_documents = reader.load_data(branch="branch")
        >>> commit_documents = reader.load_data(commit_sha="commit_sha")
"""
class FilterType(enum.Enum):
"""
Filter type.
Used to determine whether the filter is inclusive or exclusive.
Attributes:
- EXCLUDE: Exclude the files in the directories or with the extensions.
- INCLUDE: Include only the files in the directories or with the extensions.
"""
EXCLUDE = enum.auto()
INCLUDE = enum.auto()
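    # Illustrative filter tuples (hypothetical values):
    #   filter_directories=(["docs"], FilterType.INCLUDE)
    #       -> traverse only the "docs" directory tree
    #   filter_file_extensions=([".png", ".ipynb"], FilterType.EXCLUDE)
    #       -> skip images and notebooks everywhere else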
def __init__(
self,
github_client: BaseGithubClient,
owner: str,
repo: str,
use_parser: bool = True,
verbose: bool = False,
concurrent_requests: int = 5,
filter_directories: Optional[Tuple[List[str], FilterType]] = None,
filter_file_extensions: Optional[Tuple[List[str], FilterType]] = None,
):
"""
Initialize params.
Args:
- github_client (BaseGithubClient): Github client.
- owner (str): Owner of the repository.
- repo (str): Name of the repository.
- use_parser (bool): Whether to use the parser to extract
the text from the files.
- verbose (bool): Whether to print verbose messages.
- concurrent_requests (int): Number of concurrent requests to
make to the Github API.
- filter_directories (Optional[Tuple[List[str], FilterType]]): Tuple
containing a list of directories and a FilterType. If the FilterType
is INCLUDE, only the files in the directories in the list will be
included. If the FilterType is EXCLUDE, the files in the directories
in the list will be excluded.
- filter_file_extensions (Optional[Tuple[List[str], FilterType]]): Tuple
containing a list of file extensions and a FilterType. If the
FilterType is INCLUDE, only the files with the extensions in the list
will be included. If the FilterType is EXCLUDE, the files with the
extensions in the list will be excluded.
"""
super().__init__()
self._owner = owner
self._repo = repo
self._use_parser = use_parser
self._verbose = verbose
self._concurrent_requests = concurrent_requests
self._filter_directories = filter_directories
self._filter_file_extensions = filter_file_extensions
# Set up the event loop
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
# If there is no running loop, create a new one
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._github_client = github_client
def _check_filter_directories(self, tree_obj_path: str) -> bool:
"""
Check if a tree object should be allowed based on the directories.
        :param `tree_obj_path`: path of the tree object, e.g. 'gpt_index/readers'
:return: True if the tree object should be allowed, False otherwise
"""
filter_directories, filter_type = self._filter_directories
print_if_verbose(
self._verbose,
f"Checking {tree_obj_path} whether to {filter_type} it"
+ f" based on the filter directories: {filter_directories}",
)
if filter_type == self.FilterType.EXCLUDE:
return not any(
tree_obj_path.startswith(directory)
or directory.startswith(tree_obj_path)
for directory in filter_directories
)
elif filter_type == self.FilterType.INCLUDE:
return any(
tree_obj_path.startswith(directory)
or directory.startswith(tree_obj_path)
for directory in filter_directories
)
else:
raise ValueError(
f"Unknown filter type: {filter_type}. "
"Please use either 'ignore' or 'include'."
)
def _check_filter_file_extensions(self, tree_obj_path: str) -> bool:
"""
Check if a tree object should be allowed based on the file extensions.
        :param `tree_obj_path`: path of the tree object, e.g. 'gpt_index/indices'
:return: True if the tree object should be allowed, False otherwise
"""
filter_file_extensions, filter_type = self._filter_file_extensions
print_if_verbose(
self._verbose,
f"Checking {tree_obj_path} whether to {filter_type} it"
+ f" based on the filter file extensions: {filter_file_extensions}",
)
if filter_type == self.FilterType.EXCLUDE:
return get_file_extension(tree_obj_path) not in filter_file_extensions
elif filter_type == self.FilterType.INCLUDE:
return get_file_extension(tree_obj_path) in filter_file_extensions
else:
raise ValueError(
f"Unknown filter type: {filter_type}. "
"Please use either 'ignore' or 'include'."
)
def _allow_tree_obj(self, tree_obj_path: str) -> bool:
"""
Check if a tree object should be allowed.
:param `tree_obj_path`: path of the tree object
:return: True if the tree object should be allowed, False otherwise
"""
        if self._filter_directories is not None and not self._check_filter_directories(
            tree_obj_path
        ):
            return False
        if (
            self._filter_file_extensions is not None
            and not self._check_filter_file_extensions(tree_obj_path)
        ):
            return False
        return True
def _load_data_from_commit(self, commit_sha: str) -> List[Document]:
"""
Load data from a commit.
Loads github repository data from a specific commit sha.
:param `commit`: commit sha
:return: list of documents
"""
commit_response: GitCommitResponseModel = self._loop.run_until_complete(
self._github_client.get_commit(self._owner, self._repo, commit_sha)
)
tree_sha = commit_response.commit.tree.sha
blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))
print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
return self._loop.run_until_complete(
self._generate_documents(blobs_and_paths=blobs_and_paths)
)
def _load_data_from_branch(self, branch: str) -> List[Document]:
"""
Load data from a branch.
Loads github repository data from a specific branch.
:param `branch`: branch name
:return: list of documents
"""
branch_data: GitBranchResponseModel = self._loop.run_until_complete(
self._github_client.get_branch(self._owner, self._repo, branch)
)
tree_sha = branch_data.commit.commit.tree.sha
blobs_and_paths = self._loop.run_until_complete(self._recurse_tree(tree_sha))
print_if_verbose(self._verbose, f"got {len(blobs_and_paths)} blobs")
return self._loop.run_until_complete(
self._generate_documents(blobs_and_paths=blobs_and_paths)
)
def load_data(
self,
commit_sha: Optional[str] = None,
branch: Optional[str] = None,
) -> List[Document]:
"""
Load data from a commit or a branch.
Loads github repository data from a specific commit sha or a branch.
:param `commit`: commit sha
:param `branch`: branch name
:return: list of documents
"""
if commit_sha is not None and branch is not None:
raise ValueError("You can only specify one of commit or branch.")
if commit_sha is None and branch is None:
raise ValueError("You must specify one of commit or branch.")
if commit_sha is not None:
return self._load_data_from_commit(commit_sha)
if branch is not None:
return self._load_data_from_branch(branch)
raise ValueError("You must specify one of commit or branch.")
async def _recurse_tree(
self, tree_sha: str, current_path: str = "", current_depth: int = 0
) -> Any:
"""
Recursively get all blob tree objects in a tree.
And construct their full path relative to the root of the repository.
(see GitTreeResponseModel.GitTreeObject in
github_api_client.py for more information)
:param `tree_sha`: sha of the tree to recurse
:param `current_path`: current path of the tree
:param `current_depth`: current depth of the tree
:return: list of tuples of
        (tree object, file's full path relative to the root of the repo)
"""
blobs_and_full_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]] = []
print_if_verbose(
self._verbose, "\t" * current_depth + f"current path: {current_path}"
)
tree_data: GitTreeResponseModel = await self._github_client.get_tree(
self._owner, self._repo, tree_sha
)
print_if_verbose(
self._verbose, "\t" * current_depth + f"processing tree {tree_sha}"
)
for tree_obj in tree_data.tree:
file_path = os.path.join(current_path, tree_obj.path)
if tree_obj.type == "tree":
print_if_verbose(
self._verbose,
"\t" * current_depth + f"recursing into {tree_obj.path}",
)
                if (
                    self._filter_directories is not None
                    and not self._check_filter_directories(file_path)
                ):
print_if_verbose(
self._verbose,
"\t" * current_depth + f"ignoring directory {tree_obj.path} due to filter",
)
continue
blobs_and_full_paths.extend(
await self._recurse_tree(tree_obj.sha, file_path, current_depth + 1)
)
elif tree_obj.type == "blob":
print_if_verbose(
self._verbose, "\t" * current_depth + f"found blob {tree_obj.path}"
)
                if (
                    self._filter_file_extensions is not None
                    and not self._check_filter_file_extensions(file_path)
                ):
print_if_verbose(
self._verbose,
"\t" * current_depth + f"ignoring file {tree_obj.path} due to filter",
)
continue
blobs_and_full_paths.append((tree_obj, file_path))
return blobs_and_full_paths
async def _generate_documents(
self, blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]]
) -> List[Document]:
"""
Generate documents from a list of blobs and their full paths.
:param `blobs_and_paths`: list of tuples of
            (tree object, file's full path in the repo relative to the root of the repo)
:return: list of documents
"""
buffered_iterator = BufferedGitBlobDataIterator(
blobs_and_paths=blobs_and_paths,
github_client=self._github_client,
owner=self._owner,
repo=self._repo,
loop=self._loop,
buffer_size=self._concurrent_requests, # TODO: make this configurable
verbose=self._verbose,
)
documents = []
async for blob_data, full_path in buffered_iterator:
print_if_verbose(self._verbose, f"generating document for {full_path}")
assert (
blob_data.encoding == "base64"
), f"blob encoding {blob_data.encoding} not supported"
decoded_bytes = None
try:
decoded_bytes = base64.b64decode(blob_data.content)
del blob_data.content
except binascii.Error:
print_if_verbose(
self._verbose, f"could not decode {full_path} as base64"
)
continue
if self._use_parser:
document = self._parse_supported_file(
file_path=full_path,
file_content=decoded_bytes,
tree_sha=blob_data.sha,
tree_path=full_path,
)
if document is not None:
documents.append(document)
continue
print_if_verbose(
self._verbose,
f"could not parse {full_path} as a supported file type"
+ " - falling back to decoding as utf-8 raw text",
)
try:
if decoded_bytes is None:
raise ValueError("decoded_bytes is None")
decoded_text = decoded_bytes.decode("utf-8")
except UnicodeDecodeError:
print_if_verbose(
self._verbose, f"could not decode {full_path} as utf-8"
)
continue
print_if_verbose(
self._verbose,
f"got {len(decoded_text)} characters"
+ f"- adding to documents - {full_path}",
)
document = Document(
text=decoded_text,
doc_id=blob_data.sha,
extra_info={
"file_path": full_path,
"file_name": full_path.split("/")[-1],
},
)
documents.append(document)
return documents
def _parse_supported_file(
self, file_path: str, file_content: bytes, tree_sha: str, tree_path: str
) -> Optional[Document]:
"""
Parse a file if it is supported by a parser.
:param `file_path`: path of the file in the repo
:param `file_content`: content of the file
:return: Document if the file is supported by a parser, None otherwise
"""
file_extension = get_file_extension(file_path)
if (parser := DEFAULT_FILE_EXTRACTOR.get(file_extension)) is not None:
parser.init_parser()
print_if_verbose(
self._verbose,
f"parsing {file_path}"
+ f"as {file_extension} with "
+ f"{parser.__class__.__name__}",
)
with tempfile.TemporaryDirectory() as tmpdirname:
with tempfile.NamedTemporaryFile(
dir=tmpdirname,
suffix=f".{file_extension}",
mode="w+b",
delete=False,
) as tmpfile:
print_if_verbose(
self._verbose,
"created a temporary file"
+ f"{tmpfile.name} for parsing {file_path}",
)
tmpfile.write(file_content)
tmpfile.flush()
tmpfile.close()
try:
parsed_file = parser.parse_file(pathlib.Path(tmpfile.name))
parsed_file = "\n\n".join(parsed_file)
except Exception as e:
print_if_verbose(
self._verbose, f"error while parsing {file_path}"
)
logger.error(
"Error while parsing "
+ f"{file_path} with "
+ f"{parser.__class__.__name__}:\n{e}"
)
parsed_file = None
finally:
os.remove(tmpfile.name)
if parsed_file is None:
return None
return Document(
text=parsed_file,
doc_id=tree_sha,
extra_info={
"file_path": file_path,
"file_name": tree_path,
},
)
return None
if __name__ == "__main__":
import time
def timeit(func: Callable) -> Callable:
"""Time a function."""
def wrapper(*args: Any, **kwargs: Any) -> None:
"""Callcuate time taken to run a function."""
start = time.time()
func(*args, **kwargs)
end = time.time()
print(f"Time taken: {end - start} seconds for {func.__name__}")
return wrapper
github_client = GithubClient(github_token=os.environ["GITHUB_TOKEN"], verbose=True)
reader1 = GithubRepositoryReader(
github_client=github_client,
owner="jerryjliu",
repo="gpt_index",
use_parser=False,
verbose=True,
filter_directories=(
["docs"],
GithubRepositoryReader.FilterType.INCLUDE,
),
filter_file_extensions=(
[".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", "json", ".ipynb"],
GithubRepositoryReader.FilterType.EXCLUDE,
),
)
@timeit
def load_data_from_commit() -> None:
"""Load data from a commit."""
documents = reader1.load_data(
commit_sha="22e198b3b166b5facd2843d6a62ac0db07894a13"
)
for document in documents:
print(document.extra_info)
@timeit
def load_data_from_branch() -> None:
"""Load data from a branch."""
documents = reader1.load_data(branch="main")
for document in documents:
print(document.extra_info)
input("Press enter to load github repository from branch name...")
load_data_from_branch()
# input("Press enter to load github repository from commit sha...")
# load_data_from_commit()

View File: github_client.py

@@ -0,0 +1,435 @@
"""
Github API client for the GPT-Index library.
This module contains the Github API client for the GPT-Index library.
It is used by the Github readers to retrieve the data from Github.
"""
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol
from dataclasses_json import DataClassJsonMixin
@dataclass
class GitTreeResponseModel(DataClassJsonMixin):
"""
Dataclass for the response from the Github API's getTree endpoint.
Attributes:
- sha (str): SHA1 checksum ID of the tree.
- url (str): URL for the tree.
- tree (List[GitTreeObject]): List of objects in the tree.
- truncated (bool): Whether the tree is truncated.
Examples:
        >>> tree = await client.get_tree("owner", "repo", "tree_sha")
>>> tree.sha
"""
@dataclass
class GitTreeObject(DataClassJsonMixin):
"""
Dataclass for the objects in the tree.
Attributes:
- path (str): Path to the object.
- mode (str): Mode of the object.
- type (str): Type of the object.
- sha (str): SHA1 checksum ID of the object.
- url (str): URL for the object.
- size (Optional[int]): Size of the object (only for blobs).
"""
path: str
mode: str
type: str
sha: str
url: str
size: Optional[int] = None
sha: str
url: str
tree: List[GitTreeObject]
truncated: bool
@dataclass
class GitBlobResponseModel(DataClassJsonMixin):
"""
Dataclass for the response from the Github API's getBlob endpoint.
Attributes:
- content (str): Content of the blob.
- encoding (str): Encoding of the blob.
- url (str): URL for the blob.
- sha (str): SHA1 checksum ID of the blob.
- size (int): Size of the blob.
- node_id (str): Node ID of the blob.
"""
content: str
encoding: str
url: str
sha: str
size: int
node_id: str
@dataclass
class GitCommitResponseModel(DataClassJsonMixin):
"""
Dataclass for the response from the Github API's getCommit endpoint.
Attributes:
- tree (Tree): Tree object for the commit.
"""
@dataclass
class Commit(DataClassJsonMixin):
"""Dataclass for the commit object in the commit. (commit.commit)."""
@dataclass
class Tree(DataClassJsonMixin):
"""
Dataclass for the tree object in the commit.
Attributes:
- sha (str): SHA for the commit
"""
sha: str
tree: Tree
commit: Commit
@dataclass
class GitBranchResponseModel(DataClassJsonMixin):
"""
Dataclass for the response from the Github API's getBranch endpoint.
Attributes:
- commit (Commit): Commit object for the branch.
"""
@dataclass
class Commit(DataClassJsonMixin):
"""Dataclass for the commit object in the branch. (commit.commit)."""
@dataclass
class Commit(DataClassJsonMixin):
"""Dataclass for the commit object in the commit. (commit.commit.tree)."""
@dataclass
class Tree(DataClassJsonMixin):
"""
Dataclass for the tree object in the commit.
Usage: commit.commit.tree.sha
"""
sha: str
tree: Tree
commit: Commit
commit: Commit
class BaseGithubClient(Protocol):
def get_all_endpoints(self) -> Dict[str, str]:
...
async def request(
self,
endpoint: str,
method: str,
headers: Dict[str, Any] = {},
**kwargs: Any,
) -> Any:
...
async def get_tree(
self,
owner: str,
repo: str,
tree_sha: str,
) -> GitTreeResponseModel:
...
async def get_blob(
self,
owner: str,
repo: str,
file_sha: str,
) -> GitBlobResponseModel:
...
async def get_commit(
self,
owner: str,
repo: str,
commit_sha: str,
) -> GitCommitResponseModel:
...
async def get_branch(
self,
owner: str,
repo: str,
branch_name: str,
) -> GitBranchResponseModel:
...
class GithubClient:
"""
An asynchronous client for interacting with the Github API.
This client is used for making API requests to Github.
It provides methods for accessing the Github API endpoints.
The client requires a Github token for authentication,
which can be passed as an argument or set as an environment variable.
If no Github token is provided, the client will raise a ValueError.
Examples:
>>> client = GithubClient("my_github_token")
        >>> branch_info = await client.get_branch("owner", "repo", "branch")
"""
DEFAULT_BASE_URL = "https://api.github.com"
DEFAULT_API_VERSION = "2022-11-28"
def __init__(
self,
github_token: Optional[str] = None,
base_url: str = DEFAULT_BASE_URL,
api_version: str = DEFAULT_API_VERSION,
verbose: bool = False,
) -> None:
"""
Initialize the GithubClient.
Args:
- github_token (str): Github token for authentication.
If not provided, the client will try to get it from
the GITHUB_TOKEN environment variable.
- base_url (str): Base URL for the Github API
(defaults to "https://api.github.com").
- api_version (str): Github API version (defaults to "2022-11-28").
Raises:
ValueError: If no Github token is provided.
"""
if github_token is None:
github_token = os.getenv("GITHUB_TOKEN")
if github_token is None:
raise ValueError(
"Please provide a Github token. "
+ "You can do so by passing it as an argument to the GithubReader,"
+ "or by setting the GITHUB_TOKEN environment variable."
)
self._base_url = base_url
self._api_version = api_version
self._verbose = verbose
self._endpoints = {
"getTree": "/repos/{owner}/{repo}/git/trees/{tree_sha}",
"getBranch": "/repos/{owner}/{repo}/branches/{branch}",
"getBlob": "/repos/{owner}/{repo}/git/blobs/{file_sha}",
"getCommit": "/repos/{owner}/{repo}/commits/{commit_sha}",
}
self._headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {github_token}",
"X-GitHub-Api-Version": f"{self._api_version}",
}
def get_all_endpoints(self) -> Dict[str, str]:
"""Get all available endpoints."""
return {**self._endpoints}
async def request(
self,
endpoint: str,
method: str,
headers: Dict[str, Any] = {},
**kwargs: Any,
) -> Any:
"""
Make an API request to the Github API.
This method is used for making API requests to the Github API.
It is used internally by the other methods in the client.
Args:
- `endpoint (str)`: Name of the endpoint to make the request to.
- `method (str)`: HTTP method to use for the request.
- `headers (dict)`: HTTP headers to include in the request.
- `**kwargs`: Keyword arguments to pass to the endpoint URL.
Returns:
- `response (httpx.Response)`: Response from the API request.
Raises:
- ImportError: If the `httpx` library is not installed.
- httpx.HTTPError: If the API request fails.
Examples:
            >>> response = await client.request("getTree", "GET",
            ...                                 owner="owner", repo="repo",
            ...                                 tree_sha="tree_sha")
"""
try:
import httpx
except ImportError:
raise ImportError(
"Please install httpx to use the GithubRepositoryReader. "
"You can do so by running `pip install httpx`."
)
_headers = {**self._headers, **headers}
_client: httpx.AsyncClient
async with httpx.AsyncClient(
headers=_headers, base_url=self._base_url
) as _client:
try:
response = await _client.request(
method, url=self._endpoints[endpoint].format(**kwargs)
)
except httpx.HTTPError as excp:
print(f"HTTP Exception for {excp.request.url} - {excp}")
raise excp
return response
async def get_branch(
self, owner: str, repo: str, branch: str
) -> GitBranchResponseModel:
"""
Get information about a branch. (Github API endpoint: getBranch).
Args:
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `branch (str)`: Name of the branch.
Returns:
- `branch_info (GitBranchResponseModel)`: Information about the branch.
Examples:
            >>> branch_info = await client.get_branch("owner", "repo", "branch")
"""
return GitBranchResponseModel.from_json(
(
await self.request(
"getBranch", "GET", owner=owner, repo=repo, branch=branch
)
).text
)
async def get_tree(
self, owner: str, repo: str, tree_sha: str
) -> GitTreeResponseModel:
"""
Get information about a tree. (Github API endpoint: getTree).
Args:
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `tree_sha (str)`: SHA of the tree.
Returns:
- `tree_info (GitTreeResponseModel)`: Information about the tree.
Examples:
            >>> tree_info = await client.get_tree("owner", "repo", "tree_sha")
"""
return GitTreeResponseModel.from_json(
(
await self.request(
"getTree", "GET", owner=owner, repo=repo, tree_sha=tree_sha
)
).text
)
async def get_blob(
self, owner: str, repo: str, file_sha: str
) -> GitBlobResponseModel:
"""
Get information about a blob. (Github API endpoint: getBlob).
Args:
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `file_sha (str)`: SHA of the file.
Returns:
- `blob_info (GitBlobResponseModel)`: Information about the blob.
Examples:
            >>> blob_info = await client.get_blob("owner", "repo", "file_sha")
"""
return GitBlobResponseModel.from_json(
(
await self.request(
"getBlob", "GET", owner=owner, repo=repo, file_sha=file_sha
)
).text
)
async def get_commit(
self, owner: str, repo: str, commit_sha: str
) -> GitCommitResponseModel:
"""
Get information about a commit. (Github API endpoint: getCommit).
Args:
- `owner (str)`: Owner of the repository.
- `repo (str)`: Name of the repository.
- `commit_sha (str)`: SHA of the commit.
Returns:
- `commit_info (GitCommitResponseModel)`: Information about the commit.
Examples:
            >>> commit_info = await client.get_commit("owner", "repo", "commit_sha")
"""
return GitCommitResponseModel.from_json(
(
await self.request(
"getCommit", "GET", owner=owner, repo=repo, commit_sha=commit_sha
)
).text
)
if __name__ == "__main__":
import asyncio
async def main() -> None:
"""Test the GithubClient."""
client = GithubClient()
response = await client.get_tree(
owner="ahmetkca", repo="CommitAI", tree_sha="with-body"
)
for obj in response.tree:
if obj.type == "blob":
print(obj.path)
print(obj.sha)
blob_response = await client.get_blob(
owner="ahmetkca", repo="CommitAI", file_sha=obj.sha
)
print(blob_response.content)
asyncio.run(main())

View File: requirements.txt

@@ -0,0 +1 @@
httpx

View File: utils.py

@@ -0,0 +1,170 @@
"""
Github readers utils.
This module contains utility functions for the Github readers.
"""
import asyncio
import os
import time
from abc import ABC, abstractmethod
from typing import List, Tuple
from gpt_index.readers.llamahub_modules.github_repo.github_client import (
    GitBlobResponseModel,
    GithubClient,
    GitTreeResponseModel,
)
def print_if_verbose(verbose: bool, message: str) -> None:
"""Log message if verbose is True."""
if verbose:
print(message)
def get_file_extension(filename: str) -> str:
"""Get file extension."""
return f".{os.path.splitext(filename)[1][1:].lower()}"
class BufferedAsyncIterator(ABC):
"""
Base class for buffered async iterators.
This class is to be used as a base class for async iterators
that need to buffer the results of an async operation.
The async operation is defined in the _fill_buffer method.
The _fill_buffer method is called when the buffer is empty.
"""
def __init__(self, buffer_size: int):
"""
Initialize params.
Args:
- `buffer_size (int)`: Size of the buffer.
It is also the number of items that will
be retrieved from the async operation at once.
                see `_fill_buffer`. Setting it to 1
will result in the same behavior as a synchronous iterator.
"""
self._buffer_size = buffer_size
self._buffer: List[Tuple[GitBlobResponseModel, str]] = []
self._index = 0
@abstractmethod
async def _fill_buffer(self) -> None:
raise NotImplementedError
def __aiter__(self) -> "BufferedAsyncIterator":
"""Return the iterator object."""
return self
async def __anext__(self) -> Tuple[GitBlobResponseModel, str]:
"""
Get next item.
Returns:
- `item (Tuple[GitBlobResponseModel, str])`: Next item.
Raises:
- `StopAsyncIteration`: If there are no more items.
"""
if not self._buffer:
await self._fill_buffer()
if not self._buffer:
raise StopAsyncIteration
item = self._buffer.pop(0)
self._index += 1
return item
class BufferedGitBlobDataIterator(BufferedAsyncIterator):
"""
Buffered async iterator for Git blobs.
This class is an async iterator that buffers the results of the get_blob operation.
It is used to retrieve the contents of the files in a Github repository.
getBlob endpoint supports up to 100 megabytes of content for blobs.
This concrete implementation of BufferedAsyncIterator allows you to lazily retrieve
the contents of the files in a Github repository.
Otherwise you would have to retrieve all the contents of
the files in the repository at once, which would
be problematic if the repository is large.
"""
def __init__(
self,
blobs_and_paths: List[Tuple[GitTreeResponseModel.GitTreeObject, str]],
github_client: GithubClient,
owner: str,
repo: str,
loop: asyncio.AbstractEventLoop,
buffer_size: int,
verbose: bool = False,
):
"""
Initialize params.
Args:
- blobs_and_paths (List[Tuple[GitTreeResponseModel.GitTreeObject, str]]):
List of tuples containing the blob and the path of the file.
- github_client (GithubClient): Github client.
- owner (str): Owner of the repository.
- repo (str): Name of the repository.
- loop (asyncio.AbstractEventLoop): Event loop.
- buffer_size (int): Size of the buffer.
"""
super().__init__(buffer_size)
self._blobs_and_paths = blobs_and_paths
self._github_client = github_client
self._owner = owner
self._repo = repo
self._verbose = verbose
        if loop is None:
            loop = asyncio.get_event_loop()
        self._loop = loop
async def _fill_buffer(self) -> None:
"""
Fill the buffer with the results of the get_blob operation.
The get_blob operation is called for each blob in the blobs_and_paths list.
The blobs are retrieved in batches of size buffer_size.
"""
        self._buffer = []
start = self._index
end = min(start + self._buffer_size, len(self._blobs_and_paths))
if start >= end:
return
if self._verbose:
start_t = time.time()
results: List[GitBlobResponseModel] = await asyncio.gather(
*[
self._github_client.get_blob(self._owner, self._repo, blob.sha)
for blob, _ in self._blobs_and_paths[
start:end
] # TODO: use batch_size instead of buffer_size for concurrent requests
]
)
if self._verbose:
end_t = time.time()
blob_names_and_sizes = [
(blob.path, blob.size) for blob, _ in self._blobs_and_paths[start:end]
]
print(
"Time to get blobs ("
+ f"{blob_names_and_sizes}"
+ f"): {end_t - start_t:.2f} seconds"
)
self._buffer = [
(result, path)
for result, (_, path) in zip(results, self._blobs_and_paths[start:end])
]
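
# Illustrative sketch (hypothetical names): a minimal BufferedAsyncIterator
# subclass that buffers (square, label) pairs in chunks of `buffer_size`,
# demonstrating the `_fill_buffer` contract used by BufferedGitBlobDataIterator.
# The base class annotates the buffer with Github-specific types, but any
# pairs work at runtime for demonstration purposes.
if __name__ == "__main__":

    class SquaresIterator(BufferedAsyncIterator):
        """Buffer squares of 0..n-1, `buffer_size` items at a time."""

        def __init__(self, n: int, buffer_size: int = 2):
            super().__init__(buffer_size)
            self._n = n

        async def _fill_buffer(self) -> None:
            # `self._index` counts items already consumed by __anext__.
            start = self._index
            end = min(start + self._buffer_size, self._n)
            self._buffer = [(i * i, f"item-{i}") for i in range(start, end)]

    async def _demo() -> None:
        async for square, label in SquaresIterator(5):
            print(label, square)

    asyncio.run(_demo())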

View File: library.json

@@ -26,7 +26,11 @@
"CJKPDFReader": {
"id": "file/cjk_pdf",
"author": "JiroShimaya",
"keywords": ["Japanese", "Chinese", "Korean"]
"keywords": [
"Japanese",
"Chinese",
"Korean"
]
},
"DocxReader": {
"id": "file/docx",
@@ -39,7 +43,10 @@
"ImageReader": {
"id": "file/image",
"author": "ravi03071991",
"keywords": ["invoice", "receipt"]
"keywords": [
"invoice",
"receipt"
]
},
"EpubReader": {
"id": "file/epub",
@@ -68,17 +75,30 @@
"BeautifulSoupWebReader": {
"id": "web/beautiful_soup_web",
"author": "thejessezhang",
"keywords": ["substack", "readthedocs", "documentation"]
"keywords": [
"substack",
"readthedocs",
"documentation"
]
},
"RssReader": {
"id": "web/rss",
"author": "bborn",
"keywords": ["feed", "rss", "atom"]
"keywords": [
"feed",
"rss",
"atom"
]
},
"DatabaseReader": {
"id": "database",
"author": "kevinqz",
"keywords": ["sql", "postgres", "snowflake", "aws rds"]
"keywords": [
"sql",
"postgres",
"snowflake",
"aws rds"
]
},
"DiscordReader": {
"id": "discord",
@@ -154,22 +174,39 @@
"UnstructuredReader": {
"id": "file/unstructured",
"author": "thejessezhang",
"keywords": ["sec", "html", "eml", "10k", "10q", "unstructured.io"]
"keywords": [
"sec",
"html",
"eml",
"10k",
"10q",
"unstructured.io"
]
},
"KnowledgeBaseWebReader": {
"id": "web/knowledge_base",
"author": "jasonwcfan",
"keywords": ["documentation"]
"keywords": [
"documentation"
]
},
"S3Reader": {
"id": "s3",
"author": "thejessezhang",
"keywords": ["aws s3", "bucket", "amazon web services"]
"keywords": [
"aws s3",
"bucket",
"amazon web services"
]
},
"RemoteReader": {
"id": "remote",
"author": "thejessezhang",
"keywords": ["hosted", "url", "gutenberg"]
"keywords": [
"hosted",
"url",
"gutenberg"
]
},
"RemoteDepthReader": {
"id": "remote_depth",
@@ -183,12 +220,18 @@
"DadJokesReader": {
"id": "dad_jokes",
"author": "sidu",
"keywords": ["jokes", "dad jokes"]
"keywords": [
"jokes",
"dad jokes"
]
},
"WhatsappChatLoader": {
"id": "whatsapp",
"author": "batmanscode",
"keywords": ["whatsapp", "chat"]
"keywords": [
"whatsapp",
"chat"
]
},
"BilibiliTranscriptReader": {
"id": "bilibili",
@@ -197,16 +240,45 @@
"RedditReader": {
"id": "reddit",
"author": "vanessahlyan",
"keywords": ["reddit", "subreddit", "search", "comments"]
"keywords": [
"reddit",
"subreddit",
"search",
"comments"
]
},
"MemosReader": {
"id": "memos",
"author": "bubu",
"keywords": ["memos", "note"]
"keywords": [
"memos",
"note"
]
},
"SpotifyReader": {
"id": "spotify",
"author": "ong",
"keywords": [
"spotify",
"music"
]
},
"GithubRepositoryReader": {
"id": "github_repo",
"author": "ahmetkca",
"keywords": [
"github",
"repository",
"git",
"code",
"source code",
"placeholder"
],
"extra_files": [
"github_client.py",
"utils.py",
"__init__.py"
]
"keywords": ["spotify", "music"]
},
"RDFReader": {
@@ -214,4 +286,4 @@
"author": "mommi84",
"keywords": ["rdf", "n-triples", "graph", "knowledge graph"]
}
}
}

View File: tests/test_github_reader.py (new file, 104 lines)

@@ -0,0 +1,104 @@
import base64
import os
from typing import List, Tuple
from unittest.mock import AsyncMock

import pytest

# Skip by default due to network request.
# Remove this to test changes to GithubRepositoryReader.
pytest.skip(
    "Skipped by default due to network requests.", allow_module_level=True
)

from loader_hub.github_repo import GithubClient, GithubRepositoryReader
@pytest.fixture
def github_client():
return GithubClient(
github_token=os.getenv("GITHUB_API_TOKEN"),
        verbose=True,
)
@pytest.mark.asyncio
async def test_github_client(github_client):
owner = "emptycrown"
repo = "llama-hub"
branch = "main"
commit_sha = "0cd691322e5244b48b68e3588d1343eb53f3a112" # Points to Add spotify reader, https://github.com/emptycrown/llama-hub/commit/0cd691322e5244b48b68e3588d1343eb53f3a112
# test get_branch
branch_data = await github_client.get_branch(owner, repo, branch)
assert branch_data.name == branch
assert branch_data._links.self == f"https://api.github.com/repos/{owner}/{repo}/branches/{branch}", "Branch self link is incorrect"
assert branch_data._links.html == f"https://github.com/{owner}/{repo}/tree/{branch}", "Branch html link is incorrect"
# test get_commit
commit_data = await github_client.get_commit(owner, repo, commit_sha)
assert commit_data.sha == commit_sha, "Commit sha is incorrect"
assert commit_data.url == f"https://api.github.com/repos/{owner}/{repo}/commits/{commit_sha}", "Commit url is incorrect"
# test get_tree
tree_data = await github_client.get_tree(owner, repo, commit_data.commit.tree.sha)
assert tree_data.url == f"https://api.github.com/repos/{owner}/{repo}/git/trees/{commit_data.commit.tree.sha}", "Tree url is incorrect"
assert tree_data.sha == commit_data.commit.tree.sha, "Tree sha is incorrect"
    # check the tree contents at the first depth
expected_files_in_first_depth_of_the_tree: List[Tuple[str, str]] = [
("test_requirements.txt", "blob"),
("README.md", "blob"),
("Makefile", "blob"),
(".gitignore", "blob"),
("tests", "tree"),
("loader_hub", "tree"),
(".github", "tree"),
]
# check if the first depth of the tree has the expected files. All the expected files should be in the first depth of the tree and vice versa
assert len(tree_data.tree) == len(expected_files_in_first_depth_of_the_tree), "The number of files in the first depth of the tree is incorrect"
for file in expected_files_in_first_depth_of_the_tree:
assert file in [(tree_file.path, tree_file.type) for tree_file in tree_data.tree], f"{file} is not in the first depth of the tree"
# checking the opposite
for tree_obj in tree_data.tree:
assert (tree_obj.path, tree_obj.type) in expected_files_in_first_depth_of_the_tree, f"{tree_obj.path} is not in the expected files"
    # find test_requirements.txt in the tree
test_requirements_txt = [tree_obj for tree_obj in tree_data.tree if tree_obj.path == "test_requirements.txt"][0]
# test get_blob
blob_data = await github_client.get_blob(owner, repo, test_requirements_txt.sha)
assert blob_data.encoding == "base64", "Blob encoding is incorrect"
assert blob_data.url == f"https://api.github.com/repos/{owner}/{repo}/git/blobs/{test_requirements_txt.sha}", "Blob url is incorrect"
assert blob_data.sha == test_requirements_txt.sha, "Blob sha is incorrect"
# decode blob content base64-decoded string to utf-8
decoded_blob_content = base64.b64decode(blob_data.content).decode("utf-8")
expected_decoded_blob_content = """
# For testing
pytest==7.2.1
pytest-dotenv==0.5.2
# TODO: remove gpt_index after migration
https://github.com/jerryjliu/gpt_index/archive/master.zip
llama-index
# For linting
# linting stubs
types-requests==2.28.11.8
# formatting
black==22.12.0
isort==5.11.4
"""
# check if the decoded blob content is correct
    for actual, expected in zip(
        filter(lambda x: x != "", decoded_blob_content.splitlines()),
        filter(lambda x: x != "", expected_decoded_blob_content.splitlines()),
    ):
        assert actual == expected, f"{actual} is not equal to {expected}"
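
# The commit notes these tests hit the network, which is not ideal. The sketch
# below is a network-free variant of the blob check using AsyncMock against the
# same client surface; the direct import of GitBlobResponseModel and the
# fixture values are assumptions for illustration.
@pytest.mark.asyncio
async def test_get_blob_with_mock():
    from loader_hub.github_repo.github_client import GitBlobResponseModel

    mock_client = AsyncMock()
    mock_client.get_blob.return_value = GitBlobResponseModel(
        content=base64.b64encode(b"httpx\n").decode("ascii"),
        encoding="base64",
        url="https://api.github.com/repos/owner/repo/git/blobs/abc123",
        sha="abc123",
        size=6,
        node_id="node",
    )

    blob_data = await mock_client.get_blob("owner", "repo", "abc123")

    mock_client.get_blob.assert_awaited_once_with("owner", "repo", "abc123")
    assert base64.b64decode(blob_data.content).decode("utf-8") == "httpx\n"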