better thread handling
Thomas Tu committed Jan 12, 2017
1 parent 4b5208f commit 7b77aca
Showing 1 changed file with 112 additions and 112 deletions.
224 changes: 112 additions & 112 deletions eia/api.py
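
For orientation, here is a minimal usage sketch of the reworked consumer-thread pattern introduced by this commit. It assumes the module is importable as eia.api, uses a placeholder API key, and the series ids are examples only; none of this appears in the commit itself.

from eia.api import Series

# Constructing a client spawns `consumers` daemon worker threads (see __init__ below).
series = Series(api_key="YOUR_API_KEY", consumers=4)

# query() chunks the ids into groups of 100, puts each chunk on the job queue,
# waits for the consumers to drain it, and collects the parsed responses.
example_ids = ["ELEC.GEN.ALL-AK-99.A", "ELEC.GEN.ALL-AL-99.A"]  # example ids only
rows = series.query(*example_ids)
df = series.query_df(*example_ids)  # same data, flattened into a pandas DataFrame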
@@ -18,44 +18,73 @@
except ImportError:
import json

__all__ = ['Series', 'Geoset', 'Relation', 'Category', 'SeriesCategory', 'Updates', 'Search', 'BaseQuery']
__all__ = ['Series', 'Geoset', 'Relation', 'Category',
'SeriesCategory', 'Updates', 'Search', 'BaseQuery']

requests_cache.install_cache(backend="memory",
expire_after = 600,
ignored_parameters = "api_key")
expire_after=600,
ignored_parameters="api_key")
# requests_cache monkey-patches `requests` with a 10-minute in-memory cache.
# Ideally, the backend should be user-configurable.


def chunk(l, n):
"""Chunk a list into n parts."""
for i in range(0, len(l), n):
yield l[i:i+n]
yield l[i:i + n]


class BaseQuery(object):

host = "https://api.eia.gov"
endpoint = ""

def __init__(self, api_key, output="json"):
def __init__(self, api_key, output="json", consumers=4):
self.api_key = api_key
self.default_params = {"api_key" : self.api_key, "out" : output}
self.default_params = {"api_key": self.api_key, "out": output}
self.base_url = urljoin(self.host, self.endpoint)
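# Subclasses that define a parse() method get a job queue, a results queue,
# and a pool of daemon consumer threads that run parse() on queued jobs.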
if hasattr(self, 'parse'):
self.queue = Queue()
self.results = Queue()
self.consumers = []
for n in range(consumers):
consumer = Thread(target=self._consume)
consumer.daemon = True
consumer.start()
self.consumers.append(consumer)

def _consume(self):
"""Lightweight server for non-blocking requests
"""
while True:
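# Block until a job arrives: a ';'-joined id string or a dict of page
# parameters, depending on the subclass.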
job = self.queue.get()
try:
result = self.parse(job)
self.results.put(result)
except Exception as e:
warn(e)
self.queue.task_done()

def get(self, **kwargs):
params = dict(self.default_params)
params.update(kwargs)
r = requests.get(self.base_url, params = params)
r = requests.get(self.base_url, params=params)
r.raise_for_status()
return r.json()

def post(self, **data):
r = requests.post(self.base_url, params=self.default_params, data=data)
r.raise_for_status()
return r.json()
try:
r.raise_for_status()
return r.json()
except Exception as e:
print(e)
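# NOTE: on an HTTP or JSON error the exception is printed and the method
# implicitly returns None.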

def query(self, *args, **kwargs):
"""If a 'parse' method is defined, this method should queue jobs."""
raise NotImplementedError


class Series(BaseQuery):
"""
@@ -72,39 +101,32 @@ class Series(BaseQuery):

endpoint = "series/"

def parse(self, q, series_id, **kwargs):
"""Thread target to put results on a queue."""
if len(series_id) + len(self.base_url) >= 1800:
q.put(self.post(series_id = series_id, **kwargs))
else:
q.put(self.get(series_id=series_id, **kwargs))

def query(self, *series_ids, **kwargs):
q = Queue(maxsize=1000)
threads = []
for c in chunk(list(set(series_ids)), 100):
t = Thread(target = self.parse,
args = (q, ';'.join(c)), kwargs=kwargs)
t.daemon = True
t.start()
threads.append(t)
map(lambda x : x.join(), threads)
data = filter(lambda x : 'series' in x, [q.get() for _ in threads])
return [s for d in data for s in d['series']]
def parse(self, value):
return self.post(series_id=value)

def query_df(self, *series_ids, **kwargs):
"""Serialize query results to DataFrame."""
data = self.query(*series_ids, **kwargs)
def query(self, *series_ids):
"""Chunks series_ids into groups of 100 to send to a consumer."""
for i, C in enumerate(chunk(series_ids, 100)):
self.queue.put(';'.join(C))
self.queue.join() # Wait for all queries to complete
output = [self.results.get() for _ in range(i + 1)] # Deplete results
return [s for d in output if 'series' in d for s in d['series']]  # filter before indexing to avoid KeyError

def query_df(self, *series_ids):
data = self.query(*series_ids)
o = []
for d in data:
df = pd.DataFrame(d.pop('data'), columns=['period', 'value'])
# This is better for pandas >= 0.16.0
# o.append(df.assign(**d))
# This is better for pandas >= 0.16.0 : o.append(df.assign(**d))
# However, do it this way for pandas < 0.16.0 support
for k, v in d.items():
df[k] = v
o.append(df)
return pd.concat(o, ignore_index=True)
if o:
return pd.concat(o, ignore_index=True)
else:
return pd.DataFrame()


class Geoset(BaseQuery):
"""
@@ -122,8 +144,8 @@ class Geoset(BaseQuery):
endpoint = "geoset/"

def query(self, geoset_id, *regions, **kwargs):
data = self.get(geoset_id = geoset_id, regions = ','.join(regions),
**kwargs)
data = self.get(geoset_id=geoset_id, regions=','.join(regions),
**kwargs)
return data['geoset']

def query_df(self, geoset_id, *regions, **kwargs):
@@ -139,11 +161,13 @@ def query_df(self, geoset_id, *regions, **kwargs):
df[k] = v
return df


class Relation(BaseQuery):
"""Not implemented for now, this does not appear to be a valid endpoint."""
# endpoint = "relation/"
pass


class Category(BaseQuery):
"""
category_id: optional, unique numerical id of the category to fetch. If missing, the API's root category is fetched.
@@ -152,32 +176,23 @@ class Category(BaseQuery):
endpoint = "category/"

def query(self, category_id=None, **kwargs):
return self.get(category_id = category_id, **kwargs)['category']
return self.get(category_id=category_id, **kwargs)['category']


class SeriesCategory(BaseQuery):

endpoint = "series/categories/"

def parse(self, q, series_id, **kwargs):
"""Thread target to put results on a queue."""
# IMPROVEMENT: do a dry run to compute the actual URL length before choosing GET vs POST
if len(series_id) + len(self.base_url) >= 1800:
q.put(self.post(series_id = series_id, **kwargs))
else:
q.put(self.get(series_id=series_id, **kwargs))

def query(self, *series_ids, **kwargs):
q = Queue(maxsize=1000)
threads = []
for c in chunk(list(set(series_ids)), 100):
t = Thread(target = self.parse,
args = (q, ';'.join(c)), kwargs=kwargs)
t.daemon = True
t.start()
threads.append(t)
map(lambda x : x.join(), threads)
data = filter(lambda x : 'series_categories' in x, [q.get() for _ in threads])
return [s for d in data for s in d['series_categories']]
def parse(self, value):
return self.post(series_id=value)

def query(self, *series_ids):
for i, C in enumerate(chunk(series_ids, 100)):
self.queue.put(';'.join(C))
self.queue.join()
output = [self.results.get() for _ in range(i + 1)] # Deplete results
k = "series_categories"
return [s for d in output if k in d for s in d[k]]

def query_df(self, *series_ids, **kwargs):
results = self.query(*series_ids, **kwargs)
@@ -193,39 +208,28 @@ class Updates(BaseQuery):

endpoint = "updates/"

def parse(self, q, **kwargs):
q.put(self.get(**kwargs))

def query(self, category_id = None, rows=50, firstrow=0, deep = False):
def parse(self, page_params):
return self.get(**page_params)

poll = self.get(category_id = category_id,
deep=deep,
rows=1,
firstrow=0) # Check number of available rows
def query(self, category_id=None, rows=50, firstrow=0, deep=False):
poll = self.get(category_id=category_id,
deep=deep,
rows=1,
firstrow=0) # Check number of available rows
n = rows or poll['data']['rows_available']
pages = int(np.ceil(n / 10000.))
params = {"category_id" : category_id,
"deep" : deep,
"rows" : 10000}
q = Queue(maxsize=1000)
threads = []
for page in range(pages):
params = {"category_id": category_id, "deep": deep, "rows": 10000}
for page in range(int(np.ceil(n / 10000.))):
page_params = dict(params)
first = page*10000
page_params.update({"firstrow" : first})
first = page * 10000
page_params.update({"firstrow": first})
if first + 10000 > rows:
rows = rows - first
page_params.update({"rows" : rows})
t = Thread(target = self.parse,
args = (q, ), kwargs=page_params)
t.daemon = True
t.start()
threads.append(t)
map(lambda x : x.join(), threads)
data = filter(
lambda x : ('updates' in x and isinstance(x['updates'][0], dict)),
[q.get() for _ in threads])
return [s for d in data for s in d['updates']]
page_params.update({"rows": rows})
self.queue.put(page_params)
self.queue.join()
output = [self.results.get() for _ in range(page + 1)]
k = "updates"
return [s for d in output if k in d for s in d[k]]

def query_df(self, *args, **kwargs):
return pd.DataFrame(self.query(*args, **kwargs))
@@ -235,11 +239,11 @@ class Search(BaseQuery):

endpoint = "search/"

def parse(self, q, *args, **kwargs):
q.put(self.get(*args, **kwargs)['response']['docs'])
def parse(self, val):
return self.get(**val)['response']['docs']

def query(self, search_term, search_value,
rows_per_page=10, page_num=1, **kwargs):
rows_per_page=10, page_num=1):
"""
search_term : str,
one of ["series_id", "name", "last_updated"]
@@ -261,47 +265,43 @@ def query(self, search_term, search_value,
"""
t, v = self.clean_search_params(search_term, search_value)
if (rows_per_page == 0) or (rows_per_page == 'all'):
r = self.get(search_term = t, search_value = v, page_num=1, rows_per_page = 1, **kwargs)
r = self.get(search_term=t, search_value=v,
page_num=1, rows_per_page=1)
total = r['response']['numFound']
# Chunk by 7500 rows for reliability/reasonable speed
chunksize = 7500
pages = int(np.ceil(total/float(chunksize)))
params = {"search_term" : t, "search_value" : v}
params.update(kwargs)
threads = []
q = Queue(maxsize=1000)
for page in range(0, pages):
pages = int(np.ceil(total / float(chunksize)))
params = {"search_term": t, "search_value": v}
for page in range(pages):
rows = chunksize
first = page*rows
first = page * rows
if first + rows > total:
rows = total - first
page_params = dict(params)
page_params.update({"page_num" : page, "rows_per_page" : rows})
t = Thread(target=self.parse, args=(q,), kwargs=page_params)
t.daemon = True
t.start()
threads.append(t)
map(lambda x : x.join(), threads)
results = [q.get() for _ in threads]
results = [i for chunk in results for i in chunk]

else: # limit search results, paginate elsewhere
r = self.get(search_term = t, search_value = v, page_num=page_num, rows_per_page = rows_per_page, **kwargs)
page_params.update({"page_num": page, "rows_per_page": rows})
self.queue.put(page_params)
self.queue.join()
output = [self.results.get() for _ in range(pages)]
results = [i for chunk in output for i in chunk]
else: # limit search results, paginate elsewhere
r = self.get(search_term=t, search_value=v,
page_num=page_num, rows_per_page=rows_per_page)
results = r['response']['docs']
return results

def clean_search_params(self, search_term, search_value):
t, v = search_term, search_value # Alias for convenience
assert t in ["series_id", "name", "last_updated"], "Invalid search term"
t, v = search_term, search_value # Alias for convenience
assert t in ["series_id", "name",
"last_updated"], "Invalid search term"

if t in ["series_id", "name"]:
if not re.match(re.escape('^"{}"$'.format(v.strip('"'))), v):
v = '"{}"'.format(search_value)
else: # search_term == "last_updated"
if len(list(v)) == 2: # Assume it's a tuple, list, or iterable
else: # search_term == "last_updated"
if len(list(v)) == 2: # Assume it's a tuple, list, or iterable
daterange = map(pd.to_datetime, list(v))
strfdaterange = map(
lambda x : x.strftime('%Y-%m-%dT%H:%M:%SZ'),
lambda x: x.strftime('%Y-%m-%dT%H:%M:%SZ'),
daterange)
v = "[{}]".format(" TO ".join(strfdaterange))
# else : assume user has read documentation and is feeding in
