James f64be68bcc
Beta version for stackoverflow teams connections (#188)
Co-authored-by: Jerry Liu <jerryjliu98@gmail.com>
2023-04-13 00:42:01 -07:00

153 lines
6.1 KiB
Python

import json
import logging
import os
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from functools import wraps
from typing import List, Optional
import requests
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
logger = logging.getLogger(__name__)
@dataclass
class StackOverflowPost:
    """A single Stack Overflow for Teams item (question/answer post, or article).

    Field names mirror the Stack Exchange API item payload: items are expanded
    into the constructor as ``**item_dict``, with the nested ``owner`` object
    flattened into ``owner_*`` keys by the caller (StackoverflowReader.load_data).
    """

    # Fields present on every item returned by the API.
    link: str
    score: int
    last_activity_date: int  # presumably unix epoch seconds, like creation_date — TODO confirm
    creation_date: int  # unix epoch seconds (fed to datetime.fromtimestamp by the reader)
    # Post-specific fields (absent for articles).
    post_id: Optional[int] = None
    post_type: Optional[str] = None
    body_markdown: Optional[str] = None
    # Flattened from the nested ``owner`` object of the API response.
    owner_account_id: Optional[int] = None
    owner_reputation: Optional[int] = None
    owner_user_id: Optional[int] = None
    owner_user_type: Optional[str] = None
    owner_profile_image: Optional[str] = None
    owner_display_name: Optional[str] = None
    owner_link: Optional[str] = None
    # Defaulted to the item's link by the reader when the API omits it.
    title: Optional[str] = None
    last_edit_date: Optional[str] = None
    tags: Optional[List[str]] = None
    view_count: Optional[int] = None
    # Article-specific fields (absent for posts).
    article_id: Optional[int] = None
    article_type: Optional[str] = None
def rate_limit(*, allowed_per_second: float):
    """Decorator factory that throttles calls to the wrapped function.

    Guarantees at least ``1 / allowed_per_second`` seconds elapse between the
    end of one call and the start of the next. A shared lock serializes all
    callers, so the limit holds across threads.

    Args:
        allowed_per_second: Maximum number of calls permitted per second.

    Returns:
        A decorator that wraps a function with the throttling behavior.
    """
    min_interval = 1.0 / allowed_per_second
    # Seed the timestamp one full interval in the past so the very first call
    # proceeds immediately (the original seeded it at decoration time, which
    # delayed the first call by up to min_interval for no reason).
    last_finished = [time.perf_counter() - min_interval]
    lock = threading.Lock()

    def decorate(func):
        @wraps(func)
        def limit(*args, **kwargs):
            with lock:
                elapsed = time.perf_counter() - last_finished[0]
                wait = min_interval - elapsed
                if wait > 0:
                    time.sleep(wait)
                result = func(*args, **kwargs)
                # Measured from call completion, matching the original pacing.
                last_finished[0] = time.perf_counter()
                return result

        return limit

    return decorate
@rate_limit(allowed_per_second=15)
def rate_limited_get(url, headers):
    """GET ``url``, sleeping 5 minutes and retrying while the server returns 429.

    https://api.stackoverflowteams.com/docs/throttle
    https://api.stackexchange.com/docs/throttle
    Every application is subject to an IP based concurrent request throttle.
    If a single IP is making more than 30 requests a second, new requests will be dropped.
    The exact ban period is subject to change, but will be on the order of 30 seconds to a few minutes typically.
    Note that exactly what response an application gets (in terms of HTTP code, text, and so on)
    is undefined when subject to this ban; we consider > 30 request/sec per IP to be very abusive and thus cut the requests off very harshly.

    Args:
        url: Fully-formed request URL.
        headers: Headers dict (carries the X-API-Access-Token).

    Returns:
        The first ``requests.Response`` whose status is not 429.
    """
    # Iterative retry instead of the original self-recursion: a long throttling
    # episode no longer grows the call stack toward the recursion limit.
    while True:
        resp = requests.get(url, headers=headers)
        if resp.status_code != 429:
            return resp
        logger.warning('Rate limited, sleeping for 5 minutes')
        time.sleep(300)
class StackoverflowReader(BaseReader):
    """Reads posts or articles from a Stack Overflow for Teams instance.

    Credentials fall back to the ``STACKOVERFLOW_PAT`` and
    ``STACKOVERFLOW_TEAM_NAME`` environment variables. API responses can be
    cached on disk (one JSON file per page) to avoid repeat calls.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        team_name: Optional[str] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """Create a reader.

        Args:
            api_key: Personal access token; defaults to ``$STACKOVERFLOW_PAT``.
            team_name: Team slug; defaults to ``$STACKOVERFLOW_TEAM_NAME``.
            cache_dir: Optional directory for page-level response caching.
                Created if it does not exist.
        """
        self._api_key = api_key or os.environ.get('STACKOVERFLOW_PAT')
        self._team_name = team_name or os.environ.get('STACKOVERFLOW_TEAM_NAME')
        self._last_index_time = None  # TODO: support incremental indexing
        self._cache_dir = cache_dir
        if self._cache_dir:
            os.makedirs(self._cache_dir, exist_ok=True)

    def load_data(self, page: int = 1, doc_type: str = 'posts', limit: int = 50) -> List[Document]:
        """Fetch every page of ``doc_type`` starting at ``page``.

        Pages are fetched (or read from cache) until the API reports
        ``has_more`` is false.

        Args:
            page: First page number to fetch.
            doc_type: API collection to read, e.g. ``'posts'`` or ``'articles'``.
            limit: Currently unused — all pages are always fetched.
                TODO(review): honor this or remove it.

        Returns:
            One ``Document`` per item, with title/author/timestamp metadata.
        """
        data: List[Document] = []
        has_more = True
        while has_more:
            url = self.build_url(page, doc_type)
            headers = {'X-API-Access-Token': self._api_key}
            # Bug fix: only build the cache path when a cache dir is configured.
            # The original called os.path.join(None, ...) and raised TypeError
            # whenever cache_dir was left at its default of None.
            cache_fp = (
                os.path.join(self._cache_dir, f'{doc_type}_{page}.json')
                if self._cache_dir
                else None
            )
            response = {}
            if cache_fp and os.path.exists(cache_fp) and os.path.getsize(cache_fp) > 0:
                try:
                    with open(cache_fp, 'r') as f:
                        response = json.loads(f.read())
                except Exception as e:
                    # Best-effort cache: a corrupt file falls through to the API.
                    logger.error(e)
            if not response:
                resp = rate_limited_get(url, headers)
                resp.raise_for_status()
                if cache_fp:
                    with open(cache_fp, 'w') as f:
                        f.write(resp.content.decode('utf-8'))
                    logger.info(f'Wrote {cache_fp} to cache')
                response = resp.json()
            has_more = response['has_more']
            items = response['items']
            logger.info(f'Fetched {len(items)} {doc_type} from Stack Overflow')
            for item_dict in items:
                # Flatten the nested "owner" object into owner_* dataclass fields.
                owner_fields = {}
                if 'owner' in item_dict:
                    owner_fields = {
                        f"owner_{k}": v for k, v in item_dict.pop('owner').items()
                    }
                if 'title' not in item_dict:
                    item_dict['title'] = item_dict['link']
                post = StackOverflowPost(**item_dict, **owner_fields)
                # TODO: filter out posts older than self._last_index_time
                post_document = Document(
                    text=post.body_markdown,
                    doc_id=post.post_id,
                    extra_info={
                        "title": post.title,
                        "author": post.owner_display_name,
                        "timestamp": datetime.fromtimestamp(post.creation_date),
                        "location": post.link,
                        "url": post.link,
                        "author_image_url": post.owner_profile_image,
                        "type": post.post_type,
                    },
                )
                data.append(post_document)
            if has_more:
                page += 1
        return data

    def build_url(self, page: int, doc_type: str) -> str:
        """Return the Teams API URL for one page of ``doc_type``."""
        team_fragment = f'&team={self._team_name}'
        # NOTE(review): not sure if this filter is shared globally, or only to
        # a particular team.
        filter_fragment = '&filter=!nOedRLbqzB'
        page_fragment = f'&page={page}'
        return (
            f'https://api.stackoverflowteams.com/2.3/{doc_type}'
            f'?{team_fragment}{filter_fragment}{page_fragment}'
        )
if __name__ == "__main__":
    # Manual smoke test; credentials are taken from the environment.
    reader = StackoverflowReader(
        os.environ.get('STACKOVERFLOW_PAT'),
        os.environ.get('STACKOVERFLOW_TEAM_NAME'),
        cache_dir='./stackoverflow_cache',
    )
    # reader.load_data()