Skip to content

Commit

Permalink
Flipped visualizer default coloring behavior, light_mode is now the d…
Browse files Browse the repository at this point in the history
…efault, change CTC pred csv to xlsx output to allow immediate conditional formatting
  • Loading branch information
Thelukepet committed Mar 28, 2024
1 parent 367e7d2 commit e04bb04
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 45 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ python3 main.py
This will output various files into the `visualize_plots` directory:
* A PDF sheet consisting of all made visualizations for the above call
* Individual PNG and JPG files of these visualizations
* A `sample_image_preds.csv` which consist of a character prediction table for each prediction timestep. The highest probability is the character that was chosen by the model
* A `sample_image_preds.xlsx` which consists of a character prediction table for
each prediction timestep. The character with the highest probability is the one that was chosen by the model

Currently, the following visualizers are implemented:
1. **visualize_timestep_predictions**: Takes the `sample_image` and simulates the model's prediction process for each time step; the top-3 most probable characters per timestep are displayed and the "cleaned" result is shown at the bottom.
Expand All @@ -382,7 +383,7 @@ Potential future implementations:
### 3. (Optional parameters)
```bash
--do_detailed # Visualize all convolutional layers, not just the first instance of a conv layer
--light_mode # Plots and overviews are shown in light mode (instead of dark mode)
--dark_mode # Plots and overviews are shown in dark mode (instead of light mode)
--num_filters_per_row # Changes the number of filters per row in the filter activation plots (default =6)
# NOTE: increasing the num_filters_per_row requires significant computing resources, you might experience an OOM.
```
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ fpdf==1.7.2
scikit-image==0.22.0
prometheus-client==0.20.0
tf-models-official==2.14.1
xlsxwriter==3.2.0
17 changes: 9 additions & 8 deletions src/visualize/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,21 @@
pdf.add_page()

# Set color_scheme
if args.light_mode:
font_r, font_g, font_b = 0, 0, 0
else:
if args.dark_mode:
font_r, font_g, font_b = 255, 255, 255
# Draw a rectangle to fill the entire page with the default background
# color
pdf.rect(0, 0, pdf.w, pdf.h, 'F')
else:
font_r, font_g, font_b = 0, 0, 0

pdf.set_header(args.replace_header, font_r, font_g, font_b)

TS_PLOT = ("visualize_plots/timestep_prediction_plot"
+ ("_light" if args.light_mode else "_dark")
+ ("_dark" if args.dark_mode else "_light")
+ ".jpg")
ACT_PLOT = ("visualize_plots/model_new10_1channel_filters_act"
+ ("_light" if args.light_mode else "_dark")
+ ("_dark" if args.dark_mode else "_light")
+ ("_detailed" if args.do_detailed else "")
+ ".png")

Expand All @@ -58,6 +59,6 @@
+ (args.replace_header if args.replace_header
else pdf.extract_model_name(args.existing_model))
+ "_visualization_report"
+ ("_light" if args.light_mode
else "_dark")
+ ".pdf")
+ ("_dark" if args.dark_mode
else "_light")
+ ".pdf")
2 changes: 1 addition & 1 deletion src/visualize/vis_arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_arg_parser():
parser.add_argument('--sample_image_path', metavar='sample_image_path',
type=str, default="", help='single png to for '
'saliency plots')
parser.add_argument('--light_mode', action='store_true', default=False,
parser.add_argument('--dark_mode', action='store_true', default=False,
help='for setting the output image background + font '
'color')
parser.add_argument('--existing_model', metavar='existing_model', type=str,
Expand Down
6 changes: 3 additions & 3 deletions src/visualize/visualize_filters_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def main(args=None):
model)

# Top level plot
if not args.light_mode:
if args.dark_mode:
plt.style.use('dark_background')
fig = plt.figure(figsize=(5 * num_filters_per_row,
num_filters_per_row), dpi=200)
Expand Down Expand Up @@ -306,8 +306,8 @@ def main(args=None):
else "_" + str(model_channels) + "channels")
+ ("_filters_act" if args.sample_image_path
else "_filters")
+ ("_light" if args.light_mode
else "_dark")
+ ("_dark" if args.dark_mode
else "_light")
+ ("_detailed.png" if args.do_detailed
else ".png"))
plt.savefig(os.path.join(str(Path(__file__)
Expand Down
82 changes: 51 additions & 31 deletions src/visualize/visualize_timestep_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from pathlib import Path
import sys
import csv
import xlsxwriter
import re
from typing import Tuple, List

Expand Down Expand Up @@ -148,11 +148,11 @@ def get_timestep_indices(model_path: str, preds: np.ndarray,
timestep_char_list_indices_top_3, step_width, pad_steps_skip


def write_ctc_table_to_csv(preds: np.ndarray,
def write_ctc_table_to_xlsx(preds: np.ndarray,
char_list: str,
index_correction: int) -> None:
"""
Write CTC (Connectionist Temporal Classification) table data to a CSV file.
Write CTC (Connectionist Temporal Classification) table data to a xlsx file.
Parameters
----------
Expand All @@ -166,8 +166,9 @@ def write_ctc_table_to_csv(preds: np.ndarray,
Notes
-----
This function takes CTC predictions in the form of a 3D array, extracts the
data, and writes it to a CSV file. It creates columns for each timestep and
includes character labels.
data, and writes it to a xlsx file. It creates columns for each timestep and
includes character labels. Conditional formatting is added to better
distinguish between low and high prediction probabilities.
Examples
--------
Expand All @@ -177,8 +178,8 @@ def write_ctc_table_to_csv(preds: np.ndarray,
... [0.2, 0.1, 0.3]]])
>>> char_list = "ABC"
>>> index_correction = 1
>>> write_ctc_table_to_csv(preds, char_list, index_correction)
# Creates a CSV file with characters and corresponding predictions for each
>>> write_ctc_table_to_xlsx(preds, char_list, index_correction)
# Creates a xlsx file with characters and corresponding predictions for each
timestep.
"""
# Iterate through each index and tensor in preds
Expand All @@ -187,31 +188,50 @@ def write_ctc_table_to_csv(preds: np.ndarray,
# Iterate through each time step in the tensor
for time_step in tensor:
# Add the time step to the row
tensor_data.append(time_step.tolist())
tensor_data.append(np.round(time_step,3).tolist())

# Create columns
columns = ["ts_" + str(i) for i in range(preds.shape[1])]
additional_chars = ['MASK', 'BLANK'] if '' in char_list else ["BLANK"]

# Adjust char_list for special characters
characters = list(char_list) + additional_chars
transposed_data = np.transpose(tensor_data)

if not os.path.isdir(Path(__file__).with_name("visualize_plots")):
os.makedirs(Path(__file__).with_name("visualize_plots"))

# Write results to a CSV file
with open(str(Path(__file__).with_name("visualize_plots"))
+ "/sample_image_preds.csv", 'w', newline="") as csvfile:
writer = csv.writer(csvfile)

# Write the header
writer.writerow(['Chars'] + columns)

# Write the rows with index and data:
for i, row in enumerate(transposed_data):
# Don't print characters[-1] if index_correction is -1
if i + index_correction > -1:
writer.writerow(
[characters[i + index_correction]] + list(map(str, row)))
# Ensure the visualize_plots directory exists
directory_path = Path(__file__).with_name("visualize_plots")
if not os.path.isdir(directory_path):
os.makedirs(directory_path)

# Setup the XLSX file
workbook = xlsxwriter.Workbook(str(directory_path / "sample_image_preds.xlsx"))
worksheet = workbook.add_worksheet()
bold = workbook.add_format({'bold': True})

# Write the header
for i, column in enumerate(['Chars'] + columns):
worksheet.write(0, i, column, bold)

# Write the rows with index and data
for i, row in enumerate(transposed_data):
if i + index_correction > -1:
# Replace "unseeable" characters for readability in XLSX
char = (characters[i + index_correction].replace(" ","SPACE")
.replace("\t", "TAB")
.replace("\n", "NEWLINE"))
worksheet.write(i + 1, 0, char, bold)
for j, cell in enumerate(row):
worksheet.write(i + 1, j + 1, cell)

# Apply conditional formatting to create a heatmap effect
# Adjust the range accordingly to your data (+1 due to header row)
worksheet.conditional_format(1, 1, len(transposed_data) + 1, len(columns), {
'type': '2_color_scale',
'min_color': "#FFFFFF", # White
'max_color': "#00FF00" # Green
})

workbook.close()


def create_timestep_plots(bordered_img: np.ndarray, index_correction: int,
Expand Down Expand Up @@ -436,10 +456,10 @@ def main(args=None):
index_correction = -1

# Set color_scheme
if args.light_mode:
background_color, font_color = [255, 255, 255], (0, 0, 0) # Light mode
else:
if args.dark_mode:
background_color, font_color = [0, 0, 0], (255, 255, 255) # Dark mode
else:
background_color, font_color = [255, 255, 255], (0, 0, 0) # Light mode

# Dynamically calculate the right pad width required to keep all text
# readable (465px width at least)
Expand All @@ -460,8 +480,8 @@ def main(args=None):
char_list, timestep_char_list_indices,
timestep_char_list_indices_top_3)

# Take character preds for sample image and create csv file
write_ctc_table_to_csv(preds, char_list, index_correction)
# Take character preds for sample image and create xlsx file
write_ctc_table_to_xlsx(preds, char_list, index_correction)

# Double size
new_width = bordered_img.shape[1] * 2
Expand All @@ -474,7 +494,7 @@ def main(args=None):
# Save the timestep plot
cv2.imwrite(str(Path(__file__).with_name("visualize_plots"))
+ "/timestep_prediction_plot"
+ ("_light" if args.light_mode else "_dark")
+ ("_dark" if args.dark_mode else "_light")
+ ".jpg",
bordered_img_resized)

Expand Down

0 comments on commit e04bb04

Please sign in to comment.