Enable the C4 ruff lints about comprehensions #7527

Merged · 8 commits · Oct 8, 2024
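This PR turns on ruff's C4 rule group ("flake8-comprehensions") and applies the corresponding autofixes across the code base. The sketch below is illustrative only and is not part of the diff: the variable names are made up, the rule codes are quoted from memory and may be approximate, and the pyproject.toml change that actually enables the group is not shown in this excerpt. It summarizes the rewrite patterns that appear in the files that follow.

```python
# Illustrative sketch of the C4 ("flake8-comprehensions") patterns rewritten in
# this PR; names are hypothetical and rule codes are approximate.
names = ["sigma", "w"]
values = [1.0, 2.0]

# C416: a comprehension that only copies its input -> call the constructor directly
point = {k: v for k, v in zip(names, values)}  # flagged
point = dict(zip(names, values))               # preferred

# C408: dict() called with keyword arguments -> use a dict literal
opts = dict(alpha=0.6)                         # flagged
opts = {"alpha": 0.6}                          # preferred

# C413: list() wrapped around sorted(), which already returns a list
chains = list(sorted({2, 0, 1}))               # flagged
chains = sorted({2, 0, 1})                     # preferred

# C401: generator passed to set() -> use a set comprehension
nonempty = set(n for n in names if n)          # flagged
nonempty = {n for n in names if n}             # preferred

# C419: list comprehension inside any()/all() -> drop the brackets
ok = all([v > 0 for v in values])              # flagged
ok = all(v > 0 for v in values)                # preferred
```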
2 changes: 1 addition & 1 deletion benchmarks/benchmarks/benchmarks.py
@@ -206,7 +206,7 @@ def track_marginal_mixture_model_ess(self, init):
_, step = pm.init_nuts(
init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains)
)
start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
start = [dict(start) for _ in range(self.chains)]
t0 = time.time()
idata = pm.sample(
draws=self.draws,
2 changes: 1 addition & 1 deletion docs/source/learn/core_notebooks/GLM_linear.ipynb
@@ -119,7 +119,7 @@
"# add noise\n",
"y = true_regression_line + rng.normal(scale=0.5, size=size)\n",
"\n",
"data = pd.DataFrame(dict(x=x, y=y))"
"data = pd.DataFrame({\"x\": x, \"y\": y})"
]
},
{
4 changes: 2 additions & 2 deletions docs/source/learn/core_notebooks/pymc_overview.ipynb
@@ -3372,13 +3372,13 @@
" test_scores[\"score\"].values,\n",
" kind=\"hist\",\n",
" color=\"C1\",\n",
" hist_kwargs=dict(alpha=0.6),\n",
" hist_kwargs={\"alpha\": 0.6},\n",
" label=\"observed\",\n",
")\n",
"az.plot_dist(\n",
" prior_samples.prior_predictive[\"scores\"],\n",
" kind=\"hist\",\n",
" hist_kwargs=dict(alpha=0.6),\n",
" hist_kwargs={\"alpha\": 0.6},\n",
" label=\"simulated\",\n",
")\n",
"plt.xticks(rotation=45);"
2 changes: 1 addition & 1 deletion pymc/backends/base.py
@@ -348,7 +348,7 @@ def nchains(self) -> int:

@property
def chains(self) -> list[int]:
return list(sorted(self._straces.keys()))
return sorted(self._straces.keys())

@property
def report(self) -> SamplerReport:
2 changes: 1 addition & 1 deletion pymc/backends/mcbackend.py
@@ -114,7 +114,7 @@ def __init__(

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
values = self._point_fn(draw)
value_dict = {n: v for n, v in zip(self.varnames, values)}
value_dict = dict(zip(self.varnames, values))
stats_dict = self._statsbj.map(stats)
# Apply pickling to objects stats
for fname in self._statsbj.object_stats.keys():
2 changes: 1 addition & 1 deletion pymc/backends/ndarray.py
@@ -84,7 +84,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
if self._stats is None:
self._stats = []
for sampler in sampler_vars:
data: dict[str, np.ndarray] = dict()
data: dict[str, np.ndarray] = {}
self._stats.append(data)
for varname, dtype in sampler.items():
data[varname] = np.zeros(draws, dtype=dtype)
28 changes: 14 additions & 14 deletions pymc/distributions/custom.py
@@ -162,17 +162,17 @@ def rv_op(
rv_type = type(
class_name,
(CustomDistRV,),
dict(
name=class_name,
inplace=False,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
{
"name": class_name,
"inplace": False,
"ndim_supp": ndim_supp,
"ndims_params": ndims_params,
"signature": signature,
"dtype": dtype,
"_print_name": (class_name, f"\\operatorname{{{class_name}}}"),
# Specific to CustomDist
_random_fn=random,
),
"_random_fn": random,
},
)

# Dispatch custom methods
@@ -278,10 +278,10 @@ def rv_op(
class_name,
(CustomSymbolicDistRV,),
# If logp is not provided, we try to infer it from the dist graph
dict(
inline_logprob=logp is None,
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
),
{
"inline_logprob": logp is None,
"_print_name": (class_name, f"\\operatorname{{{class_name}}}"),
},
)

# Dispatch custom methods
2 changes: 1 addition & 1 deletion pymc/distributions/dist_math.py
@@ -221,7 +221,7 @@ def log_normal(x, mean, **kwargs):
rho = kwargs.get("rho")
tau = kwargs.get("tau")
eps = kwargs.get("eps", 0.0)
check = sum(map(lambda a: a is not None, [sigma, w, rho, tau]))
check = sum(a is not None for a in [sigma, w, rho, tau])
if check > 1:
raise ValueError("more than one required kwarg is passed")
if check == 0:
2 changes: 1 addition & 1 deletion pymc/distributions/distribution.py
@@ -588,7 +588,7 @@ def inline_symbolic_random_variable(fgraph, node):
"""
op = node.op
if op.inline_logprob:
return clone_replace(op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)})
return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))


# Registered before pre-canonicalization which happens at position=-10
4 changes: 2 additions & 2 deletions pymc/distributions/shape_utils.py
@@ -65,10 +65,10 @@ def to_tuple(shape):
returned. If it is array-like, tuple(shape) is returned.
"""
if shape is None:
return tuple()
return ()
temp = np.atleast_1d(shape)
if temp.size == 0:
return tuple()
return ()
else:
return tuple(temp)

24 changes: 12 additions & 12 deletions pymc/distributions/simulator.py
@@ -241,18 +241,18 @@ def rv_op(
sim_op = type(
class_name,
(SimulatorRV,),
dict(
name=class_name,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
signature=signature,
dtype=dtype,
inplace=False,
fn=fn,
_distance=distance,
_sum_stat=sum_stat,
epsilon=epsilon,
),
{
"name": class_name,
"ndim_supp": ndim_supp,
"ndims_params": ndims_params,
"signature": signature,
"dtype": dtype,
"inplace": False,
"fn": fn,
"_distance": distance,
"_sum_stat": sum_stat,
"epsilon": epsilon,
},
)()
return sim_op(*params, **kwargs)

6 changes: 3 additions & 3 deletions pymc/distributions/timeseries.py
@@ -239,9 +239,9 @@ def random_walk_logp(op, values, *inputs, **kwargs):
(value,) = values
# Recreate RV and obtain inner graph
rv_node = op.make_node(*inputs)
rv = clone_replace(
op.inner_outputs, replace={u: v for u, v in zip(op.inner_inputs, rv_node.inputs)}
)[op.default_output]
rv = clone_replace(op.inner_outputs, replace=dict(zip(op.inner_inputs, rv_node.inputs)))[
op.default_output
]
# Obtain logp of the inner graph and collapse steps dimension
return logp(rv, value).sum(axis=-1)

4 changes: 1 addition & 3 deletions pymc/func_utils.py
@@ -199,9 +199,7 @@ def find_constrained_prior(
)

# save optimal parameters
opt_params = {
param_name: param_value for param_name, param_value in zip(init_guess.keys(), opt.x)
}
opt_params = dict(zip(init_guess.keys(), opt.x))
if fixed_params is not None:
opt_params.update(fixed_params)
return opt_params
2 changes: 1 addition & 1 deletion pymc/logprob/mixture.py
@@ -506,7 +506,7 @@ def split_valued_ifelse(fgraph, node):
replacements = {first_valued_out: first_valued_ifelse}

if remaining_vars:
first_ifelse_ancestors = set(a for a in ancestors((first_then, first_else)) if a.owner)
first_ifelse_ancestors = {a for a in ancestors((first_then, first_else)) if a.owner}
remaining_thens = [then_out for (then_out, _, _, _) in remaining_vars]
remaininng_elses = [else_out for (_, else_out, _, _) in remaining_vars]
if set(remaining_thens + remaininng_elses) & first_ifelse_ancestors:
2 changes: 1 addition & 1 deletion pymc/logprob/rewriting.py
@@ -238,7 +238,7 @@ def construct_ir_fgraph(
# Replace valued RVs by ValuedVar Ops so that rewrites are aware of conditioning points
# We use clones of the value variables so that they are not affected by rewrites
cloned_values = tuple(v.clone() for v in rv_values.values())
ir_rv_values = {rv: value for rv, value in zip(fgraph.outputs, cloned_values)}
ir_rv_values = dict(zip(fgraph.outputs, cloned_values))

replacements = tuple((rv, valued_rv(rv, value)) for rv, value in ir_rv_values.items())
toposort_replace(fgraph, replacements, reverse=True)
2 changes: 1 addition & 1 deletion pymc/logprob/tensor.py
@@ -112,7 +112,7 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):
axis=axis,
)

base_rvs_to_split_values = {base_rv: value for base_rv, value in zip(base_rvs, split_values)}
base_rvs_to_split_values = dict(zip(base_rvs, split_values))
logps = [
_logprob_helper(base_var, split_value)
for base_var, split_value in base_rvs_to_split_values.items()
6 changes: 3 additions & 3 deletions pymc/sampling/forward.py
@@ -446,7 +446,7 @@ def sample_prior_predictive(
)

# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore
values = zip(*(sampler_fn() for i in range(draws)))

data = {k: np.stack(v) for k, v in zip(names, values)}
@@ -460,7 +460,7 @@

if not return_inferencedata:
return prior
ikwargs: dict[str, Any] = dict(model=model)
ikwargs: dict[str, Any] = {"model": model}
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(prior=prior, **ikwargs)
@@ -850,7 +850,7 @@ def sample_posterior_predictive(
)
sampler_fn = point_wrapper(_sampler_fn)
# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)

progress = CustomProgress(
4 changes: 2 additions & 2 deletions pymc/sampling/mcmc.py
@@ -255,7 +255,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:

def all_continuous(vars):
"""Check that vars not include discrete variables"""
if any([(var.dtype in discrete_types) for var in vars]):
if any((var.dtype in discrete_types) for var in vars):
return False
else:
return True
@@ -947,7 +947,7 @@ def _sample_return(

idata = None
if compute_convergence_checks or return_inferencedata:
ikwargs: dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
ikwargs: dict[str, Any] = {"model": model, "save_warmup": not discard_tuned_samples}
ikwargs.update(idata_kwargs)
idata = pm.to_inference_data(mtrace, **ikwargs)

2 changes: 1 addition & 1 deletion pymc/smc/kernels.py
@@ -356,7 +356,7 @@ def _posterior_to_trace(self, chain=0) -> NDArray:
var_samples = np.round(var_samples).astype(var.dtype)
value.append(var_samples.reshape(shape))
size += new_size
strace.record(point={k: v for k, v in zip(varnames, value)})
strace.record(point=dict(zip(varnames, value)))
return strace


2 changes: 1 addition & 1 deletion pymc/smc/sampling.py
@@ -293,7 +293,7 @@ def _save_sample_stats(
library=pymc,
)

ikwargs: dict[str, Any] = dict(model=model)
ikwargs: dict[str, Any] = {"model": model}
if idata_kwargs is not None:
ikwargs.update(idata_kwargs)
idata = to_inference_data(trace, **ikwargs)
18 changes: 9 additions & 9 deletions pymc/step_methods/metropolis.py
@@ -245,7 +245,7 @@ def __init__(
self.accepted_sum = np.zeros(dims, dtype=int)

# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(scaling=self.scaling, steps_until_tune=tune_interval)
self._untuned_settings = {"scaling": self.scaling, "steps_until_tune": tune_interval}

# TODO: This is not being used when compiling the logp function!
self.mode = mode
@@ -422,7 +422,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None,

vars = get_value_vars_from_user_vars(vars, model)

if not all([v.dtype in pm.discrete_types for v in vars]):
if not all(v.dtype in pm.discrete_types for v in vars):
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

super().__init__(vars, [model.compile_logp()], rng=rng)
@@ -541,7 +541,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
self.shuffle_dims = False
self.order = order

if not all([v.dtype in pm.discrete_types for v in vars]):
if not all(v.dtype in pm.discrete_types for v in vars):
raise ValueError("All variables must be binary for BinaryGibbsMetropolis")

super().__init__(vars, [model.compile_logp()], rng=rng)
@@ -1063,12 +1063,12 @@ def __init__(
# cache local history for the Z-proposals
self._history: list[np.ndarray] = []
# remember initial settings before tuning so they can be reset
self._untuned_settings = dict(
scaling=self.scaling,
lamb=self.lamb,
steps_until_tune=tune_interval,
accepted=self.accepted,
)
self._untuned_settings = {
"scaling": self.scaling,
"lamb": self.lamb,
"steps_until_tune": tune_interval,
"accepted": self.accepted,
}

self.mode = mode

14 changes: 7 additions & 7 deletions pymc/variational/approximations.py
@@ -46,7 +46,7 @@
of the method
"""

__param_spec__ = dict(mu=("d",), rho=("d",))
__param_spec__ = {"mu": ("d",), "rho": ("d",)}
short_name = "mean_field"
alias_names = frozenset(["mf"])

@@ -122,7 +122,7 @@
main drawback of the method is computational cost.
"""

__param_spec__ = dict(mu=("d",), L_tril=("int(d * (d + 1) / 2)",))
__param_spec__ = {"mu": ("d",), "L_tril": ("int(d * (d + 1) / 2)",)}
short_name = "full_rank"
alias_names = frozenset(["fr"])

@@ -193,14 +193,14 @@
"""

has_logq = False
__param_spec__ = dict(histogram=("s", "d"))
__param_spec__ = {"histogram": ("s", "d")}
short_name = "empirical"

@pytensor.config.change_flags(compute_test_value="off")
def __init_group__(self, group):
super().__init_group__(group)
self._check_trace()
if not self._check_user_params(spec_kw=dict(s=-1)):
if not self._check_user_params(spec_kw={"s": -1}):
self.shared_params = self.create_shared_params(
trace=self._kwargs.get("trace", None),
size=self._kwargs.get("size", None),
@@ -225,7 +225,7 @@
for j in range(len(trace)):
histogram[i] = DictToArrayBijection.map(trace.point(j, t)).data
i += 1
return dict(histogram=pytensor.shared(pm.floatX(histogram), "histogram"))
return {"histogram": pytensor.shared(pm.floatX(histogram), "histogram")}

def _check_trace(self):
trace = self._kwargs.get("trace", None)
@@ -236,7 +236,7 @@
" Please help us to refactor: https://github.com/pymc-devs/pymc/issues/5884"
)
elif trace is not None and not all(
[self.model.rvs_to_values[var].name in trace.varnames for var in self.group]
self.model.rvs_to_values[var].name in trace.varnames for var in self.group
):
raise ValueError("trace has not all free RVs in the group")

@@ -344,7 +344,7 @@
def __dir__(self):
d = set(super().__dir__())
d.update(self.groups[0].__dir__())
return list(sorted(d))
return sorted(d)

Codecov (codecov/patch) check warning: added line pymc/variational/approximations.py#L347 was not covered by tests.


class MeanField(SingleGroupApproximation):