Skip to content

Commit

Permalink
Cu 8692y0n0a Add mypy (#9)
Browse files Browse the repository at this point in the history
* CU-8692y0n0a Add dev-requirements

* CU-8692y0n0a Install dev-requirements in workflow

* CU-8692y0n0a Add mypy config

* CU-8692y0n0a Some typing fixes in cogstack and credentials modules

* CU-8692y0n0a Some typing fixes in python modules regarding model creation and training

* CU-8692y0n0a Fix method name typo and a typing fix for model running module

* CU-8692y0n0a Some typing fixes in mct evaluate module

* CU-8692y0n0a Add mypy to workflow

* CU-8692y0n0a Add type ignore comment for pandas chained assignment in cdb creation

* CU-8692y0n0a Add type ignore comment for pandas chained assignment in cdb creation (UMLS)
  • Loading branch information
mart-r authored Nov 27, 2023
1 parent fc8f7e3 commit 593c5c2
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 28 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Typing
# run mypy on all tracked non-test python modules
# and use explicit package base since the project
# is not set up as a python package
run: |
python -m mypy `git ls-tree --full-tree --name-only -r HEAD | grep ".py$" | grep -v "tests/"` --explicit-package-bases --follow-imports=normal
- name: Test
run: |
python -m unittest discover
Expand Down
19 changes: 8 additions & 11 deletions cogstack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import getpass
from typing import Dict, List, Any, Optional
from typing import Dict, List, Any, Optional, Iterable, Tuple
import elasticsearch
import elasticsearch.helpers
import pandas as pd
Expand All @@ -22,7 +22,8 @@ class CogStack(object):
password (str, optional): The password to use when connecting to Elasticsearch. If not provided, the user will be prompted to enter a password.
api (bool, optional): A boolean value indicating whether to use API keys or basic authentication to connect to Elasticsearch. Defaults to False (i.e., use basic authentication).
"""
def __init__(self, hosts: List, username: str=None, password: str=None, api=False):
def __init__(self, hosts: List, username: Optional[str] = None, password: Optional[str] = None,
api: bool = False):

if api:
api_username, api_password = self._check_auth_details(username, password)
Expand All @@ -36,7 +37,7 @@ def __init__(self, hosts: List, username: str=None, password: str=None, api=Fals
verify_certs=False)


def _check_auth_details(self, username=None, password=None):
def _check_auth_details(self, username=None, password=None) -> Tuple[str, str]:
"""
Prompt the user for a username and password if the values are not provided as function arguments.
Expand All @@ -53,7 +54,7 @@ def _check_auth_details(self, username=None, password=None):
password = getpass.getpass("Password: ")
return username, password

def get_docs_generator(self, index: List, query: Dict, es_gen_size: int=800, request_timeout: int=300):
def get_docs_generator(self, index: List, query: Dict, es_gen_size: int=800, request_timeout: Optional[int] = 300):
"""
Retrieve a generator object that can be used to iterate through documents in an Elasticsearch index.
Expand Down Expand Up @@ -95,12 +96,8 @@ def cogstack2df(self, query: Dict, index: str, column_headers=None, es_gen_size:
size=es_gen_size,
request_timeout=request_timeout)
temp_results = []
results = self.elastic.count(index=index, query=query['query'], request_timeout=300)
if show_progress:
_tqdm = tqdm
else:
_tqdm = _no_progress_bar
for hit in _tqdm(docs_generator, total=results['count'], desc="CogStack retrieved..."):
results = self.elastic.count(index=index, query=query['query'], request_timeout=300) # type: ignore
for hit in tqdm(docs_generator, total=results['count'], desc="CogStack retrieved...", disable=not show_progress):
row = dict()
row['_index'] = hit['_index']
row['_id'] = hit['_id']
Expand Down Expand Up @@ -144,6 +141,6 @@ def list_chunker(user_list: List[Any], n: int) -> List[List[Any]]:
return [user_list[i:i+n] for i in range(0, len(user_list), n)]


def _no_progress_bar(iterable: list, **kwargs):
def _no_progress_bar(iterable: Iterable, **kwargs):
return iterable

3 changes: 2 additions & 1 deletion credentials.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List
# CogStack login details
## Any questions on what these details are please contact your local CogStack administrator.

hosts = [] # This is a list of your CogStack ElasticSearch instances.
hosts: List[str] = [] # This is a list of your CogStack ElasticSearch instances.

## These are your login details (either via http_auth or API) Should be in str format
username = None
Expand Down
4 changes: 2 additions & 2 deletions medcat/1_create_model/create_cdb/create_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from medcat.config import Config
from medcat.cdb_maker import CDBMaker

pd.options.mode.chained_assignment = None
pd.options.mode.chained_assignment = None # type: ignore

# relative to file path
_FILE_DIR = os.path.dirname(__file__)
Expand Down Expand Up @@ -41,7 +41,7 @@
print("Cleaning acronyms...")
for i, row in csv[(~csv['acronym'].isnull()) & (csv['name_status'] == 'A')][['name', 'acronym']].iterrows():
if row['name'][0:len(row['acronym'])] == row['acronym']:
csv['name'].iloc[i] = row['acronym']
csv['name'].iloc[i] = row['acronym'] # type: ignore

print("acronyms complete")

Expand Down
2 changes: 1 addition & 1 deletion medcat/1_create_model/create_cdb/create_umls_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from medcat.config import Config
from medcat.cdb_maker import CDBMaker

pd.options.mode.chained_assignment = None
pd.options.mode.chained_assignment = None # type: ignore

# relative to file path
_FILE_DIR = os.path.dirname(__file__)
Expand Down
4 changes: 2 additions & 2 deletions medcat/2_train_model/1_unsupervised_training/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def split(self, in_file: str):
1, self.opts, self.split_identifier, header=line)
continue
buffer = buffer.process_or_write(line_nr, line)
if len(buffer.lines) > 1: # if there's more than just a header
buffer.save() # saver remaining
if buffer and len(buffer.lines) > 1: # if there's more than just a header
buffer.save() # saver remaining


def split_file(in_file: str, nr_of_lines: int, out_file_format: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
medcat_logger.addHandler(fh)

###Change parameters here###
cogstack_indices = [] # list of cogstack indexes here
cogstack_indices: list = [] # list of cogstack indexes here
text_columns = ['body_analysed'] # list of all text containing fields
# relative to file path
_FILE_DIR = os.path.dirname(__file__)
Expand All @@ -24,8 +24,8 @@
model_pack_name = ''
output_modelpack_name = '' # name of modelpack to save

cs = CogStack(hosts, api_username=api_username, api_password=api_password, api=True)
df = cs.DataFrame(index=cogstack_indices, columns=text_columns)
cs = CogStack(hosts, username=username, password=password, api=True)
df = cs.DataFrame(index=cogstack_indices, columns=text_columns) # type: ignore

cat = CAT.load_model_pack(model_pack_path+model_pack_name)
cat.cdb.print_stats()
Expand Down
4 changes: 2 additions & 2 deletions medcat/3_run_model/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
cogstack_indices = [''] # Enter your list of relevant cogstack indices here

# log size of indices
df = cs.DataFrame(index=cogstack_indices, columns=['body_analysed'])
df = cs.DataFrame(index=cogstack_indices, columns=['body_analysed']) # type: ignore
medcat_logger.warning(f'The index size is {df.shape[0]}!')
del df

Expand All @@ -44,7 +44,7 @@

data_dir = 'data'
ann_folder_path = os.path.join(base_path, data_dir, f'annotated_docs')
if not os.path.exisits(ann_folder_path):
if not os.path.exists(ann_folder_path):
os.makedirs(ann_folder_path)

medcat_logger.warning(f'Anntotations will be saved here: {ann_folder_path}')
Expand Down
13 changes: 7 additions & 6 deletions medcat/evaluate_mct_export/mct_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def user_stats(self, by_user: bool = True):
return data
return data[['user', 'count', 'date']]

def plot_user_stats(self, save_fig: bool = False, save_fig_filename: str = False):
def plot_user_stats(self, save_fig: bool = False, save_fig_filename: str = ''):
"""
Plot annotator user stats against time.
An alternative method of saving the file is: plot_user_stats().write_image("path/filename.png")
Expand Down Expand Up @@ -352,7 +352,7 @@ def meta_anns_concept_summary(self):
meta_anns_df.insert(1, 'concept_name', meta_anns_df['cui'].map(self.cat.cdb.cui2preferred_name))
return meta_anns_df

def generate_report(self, path: str = 'mct_report.xlsx', meta_ann=False, concept_filter: List = None):
def generate_report(self, path: str = 'mct_report.xlsx', meta_ann=False, concept_filter: Optional[List] = None):
"""
:param path: Outfile path
:param meta_ann: Include Meta_annotation evaluation in the summary as well
Expand All @@ -362,8 +362,9 @@ def generate_report(self, path: str = 'mct_report.xlsx', meta_ann=False, concept
if concept_filter:
with pd.ExcelWriter(path, engine_kwargs={'options': {'remove_timezone': True}}) as writer:
print('Generating report...')
df = pd.DataFrame.from_dict([self.cat.get_model_card(as_dict=True)]).T.reset_index(drop=False)
df.columns = ['MCT report', f'Generated on {date.today().strftime("%Y/%m/%d")}']
# array-like is allowed by documentation but not by typing
df = pd.DataFrame.from_dict([self.cat.get_model_card(as_dict=True)]).T.reset_index(drop=False) # type: ignore
df.columns = ['MCT report', f'Generated on {date.today().strftime("%Y/%m/%d")}'] # type: ignore
df = pd.concat([df, pd.DataFrame([['MCT Custom filter', concept_filter]], columns=df.columns)],
ignore_index = True)
df.to_excel(writer, index=False, sheet_name='medcat_model_card')
Expand All @@ -390,8 +391,8 @@ def generate_report(self, path: str = 'mct_report.xlsx', meta_ann=False, concept
else:
with pd.ExcelWriter(path, engine_kwargs={'options': {'remove_timezone': True}}) as writer:
print('Generating report...')
df = pd.DataFrame.from_dict([self.cat.get_model_card(as_dict=True)]).T.reset_index(drop=False)
df.columns = ['MCT report', f'Generated on {date.today().strftime("%Y/%m/%d")}']
df = pd.DataFrame.from_dict([self.cat.get_model_card(as_dict=True)]).T.reset_index(drop=False) # type: ignore
df.columns = ['MCT report', f'Generated on {date.today().strftime("%Y/%m/%d")}'] # type: ignore
df.to_excel(writer, index=False, sheet_name='medcat_model_card')
self.user_stats().to_excel(writer, index=False, sheet_name='user_stats')
#self.plot_user_stats().to_excel(writer, index=False, sheet_name='user_stats_plot')
Expand Down
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Global options:

[mypy]
ignore_missing_imports = True
allow_redefinition = True
5 changes: 5 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mypy
pandas-stubs
types-tqdm
types-requests
types-regex

0 comments on commit 593c5c2

Please sign in to comment.