Skip to content

Commit 1df0931

Browse files
authored
Use lists in import_ids dataframe indexing to clean up warnings (#105)
* Added Brendan's tests - passing with warnings * Applied Brendan's fix - fixes warnings, but breaks in default-argument case * fix default issue + some cleanup
1 parent ba15dff commit 1df0931

File tree

2 files changed

+45
-42
lines changed

2 files changed

+45
-42
lines changed

nfl_data_py/__init__.py

Lines changed: 22 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
name = 'nfl_data_py'
22

3-
import datetime
43
import os
54
import logging
6-
from concurrent.futures import ThreadPoolExecutor, as_completed
5+
import datetime
76
from warnings import warn
7+
from typing import Iterable
8+
from concurrent.futures import ThreadPoolExecutor, as_completed
89

9-
import appdirs
1010
import numpy
1111
import pandas
12-
from typing import Iterable
12+
import appdirs
1313

1414
# module level doc string
1515
__doc__ = """
@@ -735,52 +735,32 @@ def import_ids(columns=None, ids=None):
735735
"""Import mapping table of ids for most major data providers
736736
737737
Args:
738-
columns (List[str]): list of columns to return
739-
ids (List[str]): list of specific ids to return
738+
columns (Iterable[str]): list of columns to return
739+
ids (Iterable[str]): list of specific ids to return
740740
741741
Returns:
742742
DataFrame
743743
"""
744-
745-
# create list of id options
746-
avail_ids = ['mfl_id', 'sportradar_id', 'fantasypros_id', 'gsis_id', 'pff_id',
747-
'sleeper_id', 'nfl_id', 'espn_id', 'yahoo_id', 'fleaflicker_id',
748-
'cbs_id', 'rotowire_id', 'rotoworld_id', 'ktc_id', 'pfr_id',
749-
'cfbref_id', 'stats_id', 'stats_global_id', 'fantasy_data_id']
750-
avail_sites = [x[:-3] for x in avail_ids]
751-
752-
# check variable types
753-
if columns is None:
754-
columns = []
755-
756-
if ids is None:
757-
ids = []
758744

759-
if not isinstance(columns, list):
760-
raise ValueError('columns variable must be list.')
761-
762-
if not isinstance(ids, list):
763-
raise ValueError('ids variable must be list.')
764-
765-
# confirm id is in table
766-
if False in [x in avail_sites for x in ids]:
767-
raise ValueError('ids variable can only contain ' + ', '.join(avail_sites))
745+
columns = columns or []
746+
if not isinstance(columns, Iterable):
747+
raise ValueError('columns argument must be a list.')
748+
749+
ids = ids or []
750+
if not isinstance(ids, Iterable):
751+
raise ValueError('ids argument must be a list.')
768752

769-
# import data
770-
df = pandas.read_csv(r'https://raw.githubusercontent.com/dynastyprocess/data/master/files/db_playerids.csv')
753+
df = pandas.read_csv("https://raw.githubusercontent.com/dynastyprocess/data/master/files/db_playerids.csv")
771754

772-
rem_cols = [x for x in df.columns if x not in avail_ids]
773-
tgt_ids = [x + '_id' for x in ids]
774-
775-
# filter df to just specified columns
776-
if len(columns) > 0 and len(ids) > 0:
777-
df = df[set(tgt_ids + columns)]
778-
elif len(columns) > 0 and len(ids) == 0:
779-
df = df[set(avail_ids + columns)]
780-
elif len(columns) == 0 and len(ids) > 0:
781-
df = df[set(tgt_ids + rem_cols)]
755+
id_cols = [c for c in df.columns if c.endswith('_id')]
756+
non_id_cols = [c for c in df.columns if not c.endswith('_id')]
782757

783-
return df
758+
# filter df to just specified ids + columns
759+
ret_ids = [x + '_id' for x in ids] or id_cols
760+
ret_cols = columns or non_id_cols
761+
ret_columns = list(set([*ret_ids, *ret_cols]))
762+
763+
return df[ret_columns]
784764

785765

786766
def import_contracts():

nfl_data_py/tests/nfl_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,29 @@ def test_is_df_with_data(self):
167167
s = nfl.import_ids()
168168
self.assertEqual(True, isinstance(s, pd.DataFrame))
169169
self.assertTrue(len(s) > 0)
170+
171+
def test_import_using_ids(self):
172+
ids = ["espn", "yahoo", "gsis"]
173+
s = nfl.import_ids(ids=ids)
174+
self.assertTrue(all([f"{id}_id" in s.columns for id in ids]))
175+
176+
def test_import_using_columns(self):
177+
ret_columns = ["name", "birthdate", "college"]
178+
not_ret_columns = ["draft_year", "db_season", "team"]
179+
s = nfl.import_ids(columns=ret_columns)
180+
self.assertTrue(all([column in s.columns for column in ret_columns]))
181+
self.assertTrue(all([column not in s.columns for column in not_ret_columns]))
182+
183+
def test_import_using_ids_and_columns(self):
184+
ret_ids = ["espn", "yahoo", "gsis"]
185+
ret_columns = ["name", "birthdate", "college"]
186+
not_ret_ids = ["cfbref_id", "pff_id", "prf_id"]
187+
not_ret_columns = ["draft_year", "db_season", "team"]
188+
s = nfl.import_ids(columns=ret_columns, ids=ret_ids)
189+
self.assertTrue(all([column in s.columns for column in ret_columns]))
190+
self.assertTrue(all([column not in s.columns for column in not_ret_columns]))
191+
self.assertTrue(all([f"{id}_id" in s.columns for id in ret_ids]))
192+
self.assertTrue(all([f"{id}_id" not in s.columns for id in not_ret_ids]))
170193

171194
class test_ngs(TestCase):
172195
def test_is_df_with_data(self):

0 commit comments

Comments
 (0)