fix torch installation and bug fixes (#79)
* fix torch version for mac os

* fix error message on data_generator

* fix charting in ACD

* fix ccg columns and summary options
dayesouza authored Nov 6, 2024
1 parent e851f25 commit cc7b696
Showing 7 changed files with 367 additions and 68 deletions.
104 changes: 89 additions & 15 deletions app/workflows/anonymize_case_data/workflow.py
@@ -348,6 +348,8 @@ def create(sv: ds_variables.SessionVariables, workflow: None):
options=chart_type_options,
index=chart_type_options.index(
st.session_state[f"{workflow}_chart_type"]
if f"{workflow}_chart_type" in st.session_state
else chart_type_options[0]
),
)
if chart_type == "Top attributes":
@@ -368,19 +370,49 @@ def create(sv: ds_variables.SessionVariables, workflow: None):
f"{workflow}_chart_individual_configuration"
]
st.markdown("##### Configure top attributes chart")
print(
'chart_individual_configuration["show_attributes"]',
chart_individual_configuration["show_attributes"],
)
default_attrs = st.session_state[
f"{workflow}_chart_individual_configuration"
]["show_attributes"]
# check if default attrs are in sdf.columns()
default_attrs_existing = [
attr
for attr in default_attrs
if attr in sdf.columns.to_numpy()
]
show_attributes = st.multiselect(
"Types of top attributes to show",
options=sdf.columns.to_numpy(),
default=(
chart_individual_configuration["show_attributes"]
if (len(default_attrs_existing) == len(default_attrs))
else []
),
)
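A note on the guarded default above (illustrative sketch, not part of this diff): st.multiselect raises a StreamlitAPIException when any default value is missing from the options, so defaults remembered in session state are only reused when every one of them still exists in the current columns. A minimal version with made-up names:

import streamlit as st

options = ["age", "city", "income"]      # columns currently available
stored = ["city", "postcode"]            # attributes selected in an earlier run

# Keep the stored defaults only if all of them are still valid options;
# otherwise fall back to an empty default instead of crashing the widget.
still_valid = [attr for attr in stored if attr in options]
default = stored if len(still_valid) == len(stored) else []

show_attributes = st.multiselect("Attributes to show", options=options, default=default)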
num_values = st.number_input(
"Number of top attribute values to show",
value=chart_individual_configuration["num_values"],
)
chart_individual_configuration["show_attributes"] = (

if (
show_attributes
)
chart_individual_configuration["num_values"] = num_values
!= st.session_state[
f"{workflow}_chart_individual_configuration"
]["show_attributes"]
):
st.session_state[
f"{workflow}_chart_individual_configuration"
]["show_attributes"] = show_attributes
st.rerun()

if num_values != chart_individual_configuration["num_values"]:
st.session_state[
f"{workflow}_chart_individual_configuration"
]["num_values"] = num_values
st.rerun()
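The two blocks above follow a common Streamlit idiom, sketched here outside the diff with a hypothetical session-state key: compare the widget's current value with the copy stored in st.session_state, and only when they differ, write the new value back and call st.rerun() so the page is redrawn against the updated chart configuration.

import streamlit as st

key = "chart_num_values"  # hypothetical session-state key
st.session_state.setdefault(key, 10)

num_values = st.number_input("Number of top attribute values to show",
                             value=st.session_state[key])

# Persist and rerun only when the value actually changed, so downstream
# chart code always reads a consistent configuration from session state.
if num_values != st.session_state[key]:
    st.session_state[key] = num_values
    st.rerun()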

chart, chart_df = acd.get_bar_chart_fig(
selection=selection,
@@ -411,18 +443,33 @@ def create(sv: ds_variables.SessionVariables, workflow: None):
options=time_options,
index=time_options.index(
chart_individual_configuration["time_attribute"]
if chart_individual_configuration["time_attribute"]
in sdf.columns.to_numpy()
else None,
),
)
series_attributes = st.multiselect(
"Series attributes",
options=list(sdf.columns.to_numpy())
)
chart_individual_configuration["time_attribute"] = (

if (
time_attribute
)
chart_individual_configuration["series_attributes"] = (
series_attributes
)
!= chart_individual_configuration["time_attribute"]
):
st.session_state[
f"{workflow}_chart_individual_configuration"
]["time_attribute"] = time_attribute
st.rerun()

if (
time_attribute
!= chart_individual_configuration["series_attributes"]
):
st.session_state[
f"{workflow}_chart_individual_configuration"
]["series_attributes"] = time_attribute
st.rerun()

if time_attribute != "" and len(series_attributes) > 0:
chart, chart_df = acd.get_line_chart_fig(
@@ -459,31 +506,58 @@ def create(sv: ds_variables.SessionVariables, workflow: None):
options=attribute_type_options,
index=attribute_type_options.index(
chart_individual_configuration["source_attribute"]
if chart_individual_configuration["source_attribute"]
in attribute_type_options
else None,
),
)
target_attribute = st.selectbox(
"Target/destination attribute type",
options=attribute_type_options,
index=attribute_type_options.index(
chart_individual_configuration["target_attribute"]
if chart_individual_configuration["target_attribute"]
in attribute_type_options
else None,
),
)
highlight_attribute = st.selectbox(
"Highlight attribute",
options=highlight_options,
index=highlight_options.index(
chart_individual_configuration["highlight_attribute"]
if chart_individual_configuration["highlight_attribute"]
in highlight_options
else None,
),
)
chart_individual_configuration["source_attribute"] = (

if (
source_attribute
)
chart_individual_configuration["target_attribute"] = (
!= chart_individual_configuration["source_attribute"]
):
st.session_state[
f"{workflow}_chart_individual_configuration"
]["source_attribute"] = source_attribute
st.rerun()

if (
target_attribute
)
chart_individual_configuration["highlight_attribute"] = (
!= chart_individual_configuration["target_attribute"]
):
st.session_state[
f"{workflow}_chart_individual_configuration"
]["target_attribute"] = target_attribute
st.rerun()

if (
highlight_attribute
)
!= chart_individual_configuration["highlight_attribute"]
):
st.session_state[
f"{workflow}_chart_individual_configuration"
]["highlight_attribute"] = highlight_attribute
st.rerun()

if source_attribute != "" and target_attribute != "":
# export_df = compute_flow_query(selection, sv.anonymize_synthetic_df.value, adf, att_separator, val_separator, data_schema, source_attribute, target_attribute, highlight_attribute)
43 changes: 36 additions & 7 deletions app/workflows/compare_case_groups/workflow.py
@@ -70,28 +70,57 @@ def create(sv: gn_variables.SessionVariables, workflow=None):
st.markdown("##### Define summary model")
sorted_cols = sorted(sv.case_groups_final_df.value.columns)

default_groups_existing = [
attr for attr in sv.case_groups_groups.value if attr in sorted_cols
]
default_atts_existing = [
attr
for attr in sv.case_groups_aggregates.value
if attr in sorted_cols
]

case_group_options = ccg.get_filter_options(
pl.from_pandas(sv.case_groups_final_df.value)
)
default_filters_existing = [
attr
for attr in sv.case_groups_filters.value
if attr in case_group_options
]

groups = st.multiselect(
"Compare groups of records with different combinations of these attributes:",
sorted_cols,
default=sv.case_groups_groups.value
if (
len(default_groups_existing) == len(sv.case_groups_groups.value)
)
else [],
)
aggregates = st.multiselect(
"Using counts of these attributes:",
sorted_cols,
default=sv.case_groups_aggregates.value
if (len(default_atts_existing) == len(sv.case_groups_groups.value))
else [],
)
temporal_options = ["", *sorted_cols]
temporal = st.selectbox(
"Across windows of this temporal/ordinal attribute (optional):",
temporal_options,
index=temporal_options.index(sv.case_groups_temporal.value)
if sv.case_groups_temporal.value in temporal_options
else 0,
)
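The selectbox above applies the same defensive idea as the guarded defaults earlier: list.index() raises ValueError when the stored value is no longer among the options, so the index is only computed when the value is still present, otherwise the first (empty) option is used. A tiny sketch with made-up values, not part of this commit:

temporal_options = ["", "year", "quarter"]
stored = "month"  # stale value carried over from a different dataset

# Avoid ValueError from .index() by checking membership first.
index = temporal_options.index(stored) if stored in temporal_options else 0
print(index)  # 0 -> the empty option is selected by default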
filters = st.multiselect(
"After filtering to records matching these values (optional):",
case_group_options,
default=sv.case_groups_filters.value
if (
len(default_filters_existing)
== len(sv.case_groups_filters.value)
)
else [],
)

create = st.button(
4 changes: 2 additions & 2 deletions intelligence_toolkit/AI/utils.py
@@ -11,11 +11,11 @@

import tiktoken

from intelligence_toolkit.AI.base_chat import BaseChat
from intelligence_toolkit.AI.client import OpenAIClient
from intelligence_toolkit.AI.defaults import DEFAULT_ENCODING, DEFAULT_REPORT_BATCH_SIZE
from intelligence_toolkit.AI.validation_prompt import GROUNDEDNESS_PROMPT
from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback

log = logging.getLogger(__name__)

7 changes: 4 additions & 3 deletions intelligence_toolkit/compare_case_groups/api.py
@@ -80,7 +80,7 @@ def get_filter_options(self, input_df: pl.DataFrame) -> list[str]:
return sorted_atts

def _select_columns_ranked_df(self, ranked_df: pl.DataFrame) -> None:
columns = self.groups
default_columns = [
"group_count",
"group_rank",
@@ -100,7 +100,6 @@ def _select_columns_ranked_df(self, ranked_df: pl.DataFrame) -> None:
f"{self.temporal}_window_delta",
]
)

self.model_df = ranked_df.select(columns)
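Context for the "fix ccg columns" bullet (an illustrative sketch, not part of the diff): polars selects columns case-sensitively, so lowercasing the group names before select() fails whenever the dataframe keeps the original casing. With made-up column names:

import polars as pl

ranked_df = pl.DataFrame({"Region": ["N", "S"], "group_count": [3, 2]})

# ranked_df.select(["region", "group_count"])  # raises ColumnNotFoundError
print(ranked_df.select(["Region", "group_count"]))  # works: names used as given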

def create_data_summary(
@@ -136,7 +135,9 @@ def create_data_summary(
if temporal:
window_df = create_window_df(groups, temporal, aggregates, self.filtered_df)

temporal_atts = sorted(
self.model_df[temporal].cast(pl.Utf8).unique().drop_nulls()
)
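Dropping nulls first matters here because sorted() falls back to Python comparisons, and comparing None against a string raises TypeError. A quick illustration on a toy series, not part of the diff:

import polars as pl

s = pl.Series("period", ["2021", None, "2020"])

# sorted(s.cast(pl.Utf8).unique())  # TypeError: '<' not supported
print(sorted(s.cast(pl.Utf8).unique().drop_nulls()))  # ['2020', '2021']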

temporal_df = build_temporal_data(
window_df, groups, temporal_atts, temporal
8 changes: 6 additions & 2 deletions intelligence_toolkit/generate_mock_data/data_generator.py
@@ -7,7 +7,6 @@
import intelligence_toolkit.AI.utils as utils
import intelligence_toolkit.generate_mock_data.prompts as prompts
import intelligence_toolkit.generate_mock_data.schema_builder as schema_builder
from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback


@@ -64,7 +63,12 @@ async def generate_data(

for new_object in new_objects:
print(new_object)
try:
new_object_json = loads(new_object)
except Exception as e:
msg = f"AI did not return a valid JSON response. Please try again. {e}"
raise ValueError(msg) from e

generated_objects.append(new_object_json)
current_object_json, conflicts = merge_json_objects(
current_object_json, new_object_json
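The try/except above replaces an opaque json decode traceback with an actionable error when the model's reply is not valid JSON. A self-contained version of the same guard, using a hypothetical helper name:

from json import loads


def parse_generated_object(raw: str) -> dict:
    """Parse one model response, raising a clear error on malformed JSON."""
    try:
        return loads(raw)
    except Exception as e:
        msg = f"AI did not return a valid JSON response. Please try again. {e}"
        raise ValueError(msg) from e


print(parse_generated_object('{"name": "example"}'))  # {'name': 'example'}
# parse_generated_object("not json")                  # ValueError with the message above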
[Diffs for the remaining changed files did not load.]
