Skip to content

Commit 15f66fa

Browse files
committed
add support for colored inputs (i. e. three input channel)
Added cifar10 for testing colored inputs.
1 parent d4cf856 commit 15f66fa

File tree

5 files changed

+101
-27
lines changed

5 files changed

+101
-27
lines changed

playground/04_custom_toplevel.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from typing import Dict, List, Optional
55

66
from bitstring import Bits
7+
import larq as lq
78
import numpy as np
9+
import tensorflow as tf
810

911
# TODO: copied from test_utils.general
1012
def to_fixedint(number: int, bitwidth: int, is_unsigned: bool = True):
@@ -382,18 +384,18 @@ def __init__(
382384
self.output_classes = output_classes
383385
self.output_bitwidth = output_bitwidth
384386
self.previous_layer_info = {
385-
"name": "in",
387+
"name": "in_deserialized",
386388
"channel": input_channel,
387389
"bitwidth": input_bitwidth,
388390
"width": image_width,
389391
"height": image_height,
390392
}
391393

392-
self.input_data_signal = Parameter(
394+
self.input_data_signal_deserialized = Parameter(
393395
f"slv_data_{self.previous_layer_info['name']}",
394-
f"std_logic_vector(8 - 1 downto 0)",
396+
f"std_logic_vector(C_INPUT_CHANNEL * C_INPUT_CHANNEL_BITWIDTH - 1 downto 0)",
395397
)
396-
self.input_control_signal = Parameter(
398+
self.input_control_signal_deserialized = Parameter(
397399
f"sl_valid_{self.previous_layer_info['name']}", "std_logic"
398400
)
399401

@@ -419,7 +421,7 @@ def __init__(
419421
isl_clk : in std_logic;
420422
isl_start : in std_logic;
421423
isl_valid : in std_logic;
422-
islv_data : in std_logic_vector(C_INPUT_CHANNEL * C_INPUT_CHANNEL_BITWIDTH - 1 downto 0);
424+
islv_data : in std_logic_vector(C_INPUT_CHANNEL_BITWIDTH - 1 downto 0);
423425
oslv_data : out std_logic_vector(C_OUTPUT_CHANNEL_BITWIDTH - 1 downto 0);
424426
osl_valid : out std_logic;
425427
osl_finish : out std_logic
@@ -441,14 +443,32 @@ def to_vhdl(self):
441443
declarations.append("-- input signals")
442444
declarations.append(
443445
parameter_to_vhdl(
444-
"signal", [self.input_data_signal, self.input_control_signal]
446+
"signal",
447+
[
448+
self.input_data_signal_deserialized,
449+
self.input_control_signal_deserialized,
450+
],
445451
)
446452
)
447453
declarations.append("")
448454

449455
# connect input signals
450-
implementation.append(f"{self.input_control_signal.name} <= isl_valid;")
451-
implementation.append(f"{self.input_data_signal.name} <= islv_data;")
456+
implementation.append(
457+
f"""
458+
i_deserializer : entity util.deserializer
459+
generic map (
460+
C_DATA_COUNT => C_INPUT_CHANNEL,
461+
C_DATA_BITWIDTH => C_INPUT_CHANNEL_BITWIDTH
462+
)
463+
port map (
464+
isl_clk => isl_clk,
465+
isl_valid => isl_valid,
466+
islv_data => islv_data,
467+
oslv_data => {self.input_data_signal_deserialized.name},
468+
osl_valid => {self.input_control_signal_deserialized.name}
469+
);
470+
"""
471+
)
452472

453473
# parse the bnn
454474
for layer in self.layers:
@@ -576,15 +596,11 @@ def get_stride(strides):
576596

577597

578598
def bnn_from_larq(path: str) -> Bnn:
579-
import larq as lq
580-
import tensorflow as tf
581-
582599
model = tf.keras.models.load_model(path)
583600
lq.models.summary(model)
584601

585-
input_channel = 1
602+
input_channel = model.input.shape[-1]
586603
input_channel_bitwidth = 8
587-
output_channel = 8
588604
output_channel_bitwidth = 8
589605
bnn = Bnn(
590606
*model.input.shape[1:], # h x w x ch

playground/05_intro_modified.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,22 @@
22
import tensorflow as tf
33

44
# for resizing
5-
import cv2
65
import numpy as np
76

8-
(train_images, train_labels), (
9-
test_images,
10-
test_labels,
11-
) = tf.keras.datasets.mnist.load_data()
7+
dataset = "mnist"
8+
if dataset == "mnist":
9+
(
10+
(train_images, train_labels),
11+
(test_images, test_labels,),
12+
) = tf.keras.datasets.mnist.load_data()
1213

13-
# reshape the inputs
14-
# train_images = np.stack([cv2.resize(img, (22, 22)) for img in train_images])
15-
train_images = train_images.reshape((60000, 28, 28, 1))
16-
# test_images = np.stack([cv2.resize(img, (22, 22)) for img in test_images])
17-
test_images = test_images.reshape((10000, 28, 28, 1))
14+
train_images = train_images.reshape((60000, 28, 28, 1))
15+
test_images = test_images.reshape((10000, 28, 28, 1))
16+
else:
17+
(
18+
(train_images, train_labels),
19+
(test_images, test_labels,),
20+
) = tf.keras.datasets.cifar10.load_data()
1821

1922
# All quantized layers except the first will use the same options
2023
kwargs = dict(
@@ -50,6 +53,10 @@
5053
model.add(lq.layers.QuantConv2D(64, (1, 1), use_bias=False, **kwargs))
5154
model.add(tf.keras.layers.BatchNormalization(scale=False))
5255
# model.add(tf.keras.layers.Dropout(0.2))
56+
57+
if dataset == "mnist":
58+
model.add(lq.layers.QuantConv2D(128, (1, 1), use_bias=False, **kwargs))
59+
model.add(tf.keras.layers.BatchNormalization(scale=False))
5360
if True:
5461
model.add(lq.layers.QuantConv2D(10, (1, 1), use_bias=False, **kwargs))
5562
model.add(lq.layers.tf.keras.layers.GlobalAveragePooling2D())

sim/test_bnn.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
async def run_test(dut):
2525
height = dut.C_INPUT_HEIGHT.value.integer
2626
width = dut.C_INPUT_WIDTH.value.integer
27-
input_image = np.random.randint(0, 255, (1, height, width, 1), dtype=np.uint8)
28-
# input_image = np.full((1, height, width, 1), 0, dtype=np.uint8)
27+
channel = dut.C_INPUT_CHANNEL.value.integer
28+
input_image = np.random.randint(0, 255, (1, height, width, channel), dtype=np.uint8)
29+
# input_image = np.full((1, height, width, channel), 0, dtype=np.uint8)
2930

3031
# TODO: How to disable the custom gradient warning?
3132
model = tf.keras.models.load_model("../../models/test")

sim/test_bnn_uart.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
async def run_test(dut):
1818
height = dut.i_bnn.C_INPUT_HEIGHT.value.integer
1919
width = dut.i_bnn.C_INPUT_WIDTH.value.integer
20+
channel = dut.i_bnn.C_INPUT_CHANNEL.value.integer
2021
classes = dut.i_bnn.C_OUTPUT_CHANNEL.value.integer
2122

22-
# input_image = np.random.randint(0, 255, (1, height, width, 1), dtype=np.uint8)
23-
input_image = np.full((1, height, width, 1), 0, dtype=np.uint8)
23+
# input_image = np.random.randint(0, 255, (1, height, width, channel), dtype=np.uint8)
24+
input_image = np.full((1, height, width, channel), 0, dtype=np.uint8)
2425

2526
# initialize the test
2627
clock_period = 40 # ns

src/util/deserializer.vhd

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
library ieee;
2+
use ieee.std_logic_1164.all;
3+
4+
entity deserializer is
5+
generic (
6+
C_DATA_COUNT : integer := 4;
7+
C_DATA_BITWIDTH : integer := 8
8+
);
9+
port (
10+
isl_clk : in std_logic;
11+
isl_valid : in std_logic;
12+
islv_data : in std_logic_vector(C_DATA_BITWIDTH - 1 downto 0);
13+
oslv_data : out std_logic_vector(C_DATA_COUNT * C_DATA_BITWIDTH - 1 downto 0);
14+
osl_valid : out std_logic
15+
);
16+
end entity deserializer;
17+
18+
architecture rtl of deserializer is
19+
20+
signal int_input_count : integer range 0 to C_DATA_COUNT - 1 := 0;
21+
signal slv_data_out : std_logic_vector(oslv_data'range) := (others => '0');
22+
signal sl_valid_out : std_logic := '0';
23+
24+
begin
25+
26+
proc_deserializer : process (isl_clk) is
27+
begin
28+
29+
if (rising_edge(isl_clk)) then
30+
sl_valid_out <= '0';
31+
32+
if (isl_valid = '1') then
33+
slv_data_out <= islv_data & slv_data_out(slv_data_out'high downto C_DATA_BITWIDTH);
34+
35+
if (int_input_count /= C_DATA_COUNT - 1) then
36+
int_input_count <= int_input_count + 1;
37+
else
38+
int_input_count <= 0;
39+
sl_valid_out <= '1';
40+
end if;
41+
end if;
42+
end if;
43+
44+
end process proc_deserializer;
45+
46+
osl_valid <= sl_valid_out;
47+
oslv_data <= slv_data_out;
48+
49+
end architecture rtl;

0 commit comments

Comments
 (0)