Skip to content

Commit

Permalink
Remove config dataclass in param_test. prepare for residual add
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Nov 10, 2023
1 parent 957700c commit deb6c09
Showing 1 changed file with 13 additions and 37 deletions.
50 changes: 13 additions & 37 deletions test/py/param_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,6 @@ def compile(c):
return c


@dataclass
class Config:
K : int
CO: int
is_bias: bool
act_q: str
strides: (int,int) = (1,1)
pool_d: dict = None
flatten: bool = False
dense: bool = False


@pytest.mark.parametrize("COMPILE", list(product_dict(
X_BITS = [4 ],
K_BITS = [4 ],
Expand All @@ -190,37 +178,25 @@ class Config:
)))
def test_dnn_engine(COMPILE):
c = make_compile_params(COMPILE)

input_shape = (8,18,18,3) # (XN, XH, XW, CI)
model_config = [
Config(11, 8, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)', (2,1), pool_d={'type':'max', 'size':(3,4), 'strides':(2,3), 'padding':'same', 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)'}),
Config(1 , 8, False, f'quantized_bits({c.X_BITS},0,False,False,1)'),
Config(7 , 8, True , f'quantized_bits({c.X_BITS},0,False,True,1)'),
Config(5 , 8, False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'),
Config(3 , 24, True , f'quantized_relu({c.X_BITS},0,negative_slope=0)'),
Config(1 , 10 , False, f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', flatten=True),
Config(1 , 10, True , f'quantized_relu({c.X_BITS},0,negative_slope=0.125)', dense= True),
]

'''
Build Model
'''
assert c.X_BITS in [1,2,4,8] and c.K_BITS in [1,2,4,8], "X_BITS and K_BITS should be in [1,2,4,8]"
assert c.B_BITS in [8,16,32], "B_BITS should be in [8,16,32]"
xq, kq, bq = f'quantized_bits({c.X_BITS},0,False,True,1)', f'quantized_bits({c.K_BITS},0,False,True,1)', f'quantized_bits({c.B_BITS},0,False,True,1)'
inp = {'bits':c.X_BITS, 'frac':c.X_BITS-1}
inp = {'bits':c.X_BITS, 'frac':c.X_BITS-1}

'''
Build Model
'''
input_shape = (1,18,18,3) # (XN, XH, XW, CI)
x = x_in = Input(input_shape[1:], name='input')
x = QActivation(xq)(x)
for i, g in enumerate(model_config):
if g.dense:
d = {'core': {'type':'dense', 'units':g.CO, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':g.is_bias, 'act_str':g.act_q}}
else:
d = {
'core': {'type':'conv', 'filters':g.CO, 'kernel_size':(g.K,g.K), 'strides':g.strides, 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':g.is_bias, 'act_str':g.act_q},
'pool': g.pool_d, 'flatten':g.flatten,
}
x = Bundle(**d)(x)

x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':(11,11), 'strides':(2,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':f'quantized_relu({c.X_BITS},0,negative_slope=0)' }, pool= {'type':'max', 'size':(3,4), 'strides':(2,3), 'padding':'same', 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)'})(x)
x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 1, 1), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':f'quantized_bits({c.X_BITS},0,False,False,1)' },)(x)
x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 7, 7), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':False, 'act_str':f'quantized_bits({c.X_BITS},0,False,True,1)' },)(x)
x = Bundle( core= {'type':'conv' , 'filters':8 , 'kernel_size':( 5, 5), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'},)(x)
x = Bundle( core= {'type':'conv' , 'filters':24, 'kernel_size':( 3, 3), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':f'quantized_relu({c.X_BITS},0,negative_slope=0)' },)(x)
x = Bundle( core= {'type':'conv' , 'filters':10, 'kernel_size':( 1, 1), 'strides':(1,1), 'padding':'same', 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'}, flatten= True)(x)
x = Bundle( core= {'type':'dense', 'units' :10, 'kernel_quantizer':kq, 'bias_quantizer':bq, 'use_bias':True , 'act_str':f'quantized_relu({c.X_BITS},0,negative_slope=0.125)'})(x)

model = Model(inputs=x_in, outputs=x)

Expand Down

0 comments on commit deb6c09

Please sign in to comment.