Skip to content

Commit 0112bf8

Browse files
author
Bradley Augstein
committed
#151 - reduce bass class code, simplify column and spin logic
1 parent f92f0d3 commit 0112bf8

File tree

2 files changed

+63
-92
lines changed

2 files changed

+63
-92
lines changed

heracles/_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def transform(
177177
msg = f"unknown field name: {k}"
178178
raise ValueError(msg)
179179

180-
out[k, i] = field.mapper_or_error.transform(m)
180+
out[k, i] = field.mapper.transform(m)
181181

182182
if progress:
183183
subtask.remove()

heracles/fields.py

Lines changed: 62 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -64,95 +64,45 @@ class Field(metaclass=ABCMeta):
6464
# definition of required and optional columns
6565
__ncol: tuple[int, int]
6666

67-
def __init_subclass__(cls, *, spin: int | None = None) -> None:
67+
''' def __init_subclass__(cls, *, spin: int | None = None) -> None:
6868
"""Initialise spin weight of field subclasses."""
6969
super().__init_subclass__()
70-
if spin is not None:
71-
cls.__spin = spin
72-
uses = cls.uses
73-
if uses is None:
74-
uses = ()
75-
elif isinstance(uses, str):
76-
uses = (uses,)
77-
ncol = len(uses)
78-
nopt = 0
79-
for u in uses[::-1]:
80-
if u.startswith("[") and u.endswith("]"):
81-
nopt += 1
82-
else:
83-
break
84-
cls.__ncol = (ncol - nopt, ncol)
70+
cls.__spin = spin'''
71+
8572

8673
def __init__(
8774
self,
8875
mapper: Mapper | None,
89-
*columns: str,
76+
*columns: str,
77+
weight: str | None,
9078
mask: str | None = None,
9179
) -> None:
9280
"""Initialise the field."""
9381
super().__init__()
9482
self.__mapper = mapper
95-
self.__columns = self._init_columns(*columns) if columns else None
83+
self.__columns = columns if columns else None
84+
self.__weight = weight if weight else None
9685
self.__mask = mask
97-
98-
@classmethod
99-
def _init_columns(cls, *columns: str) -> Columns:
100-
"""Initialise the given set of columns for a specific field
101-
subclass."""
102-
nmin, nmax = cls.__ncol
103-
if not nmin <= len(columns) <= nmax:
104-
uses = cls.uses
105-
if uses is None:
106-
uses = ()
107-
if isinstance(uses, str):
108-
uses = (uses,)
109-
count = f"{nmin}"
110-
if nmax != nmin:
111-
count += f" to {nmax}"
112-
msg = f"field of type '{cls.__name__}' accepts {count} columns"
113-
if uses:
114-
msg += " (" + ", ".join(uses) + ")"
115-
msg += f", received {len(columns)}"
116-
raise ValueError(msg)
117-
return columns + (None,) * (nmax - len(columns))
86+
self.__spin = 0
11887

11988
@property
12089
def mapper(self) -> Mapper | None:
12190
"""Return the mapper used by this field."""
12291
return self.__mapper
123-
92+
12493
@property
125-
def mapper_or_error(self) -> Mapper:
126-
"""Return the mapper used by this field, or raise a :class:`ValueError`
127-
if not set."""
128-
if self.__mapper is None:
129-
msg = "no mapper for field"
130-
raise ValueError(msg)
131-
return self.__mapper
94+
def weight(self) -> str | None:
95+
"""Return the mapper used by this field."""
96+
return self.__weight
13297

13398
@property
13499
def columns(self) -> Columns | None:
135100
"""Return the catalogue columns used by this field."""
136101
return self.__columns
137102

138-
@property
139-
def columns_or_error(self) -> Columns:
140-
"""Return the catalogue columns used by this field, or raise a
141-
:class:`ValueError` if not set."""
142-
if self.__columns is None:
143-
msg = "no columns for field"
144-
raise ValueError(msg)
145-
return self.__columns
146-
147103
@property
148104
def spin(self) -> int:
149-
"""Spin weight of field."""
150-
spin = self.__spin
151-
if spin is None:
152-
clsname = self.__class__.__name__
153-
msg = f"field of type '{clsname}' has undefined spin weight"
154-
raise ValueError(msg)
155-
return spin
105+
return self.__spin
156106

157107
@property
158108
def mask(self) -> str | None:
@@ -168,7 +118,13 @@ async def __call__(
168118
) -> ArrayLike:
169119
"""Implementation for mapping a catalogue."""
170120
...
171-
121+
122+
def CheckColumns(self, *expected):
123+
if(len(expected)!=len(self.columns)):
124+
error = "Column error. Expected " + str(len(expected)) + " columns"
125+
error += "with a format " + str(expected) + ". Received " + str(self.columns)
126+
raise ValueError(error)
127+
172128

173129
async def _pages(
174130
catalog: Catalog,
@@ -190,26 +146,25 @@ async def _pages(
190146
await coroutines.sleep()
191147

192148

193-
class Positions(Field, spin=0):
149+
class Positions(Field):
194150
"""Field of positions in a catalogue.
195151
196152
Can produce both overdensity maps and number count maps, depending
197153
on the ``overdensity`` property.
198154
199155
"""
200156

201-
uses = "longitude", "latitude"
202-
203157
def __init__(
204158
self,
205159
mapper: Mapper | None,
206160
*columns: str,
161+
weight: str | None,
207162
overdensity: bool = True,
208163
nbar: float | None = None,
209164
mask: str | None = None,
210165
) -> None:
211166
"""Create a position field."""
212-
super().__init__(mapper, *columns, mask=mask)
167+
super().__init__(mapper, *columns,weight=weight, mask=mask)
213168
self.__overdensity = overdensity
214169
self.__nbar = nbar
215170

@@ -241,11 +196,15 @@ async def __call__(
241196
msg = "cannot compute density contrast: no visibility in catalog"
242197
raise ValueError(msg)
243198

244-
# get mapper
245-
mapper = self.mapper_or_error
199+
#get mapper
200+
mapper = self.mapper
246201

247202
# get catalogue column definition
248-
col = self.columns_or_error
203+
col = self.columns
204+
self.CheckColumns("longitude", "latitude")
205+
206+
#if(len(col)!=2):
207+
# raise ValueError("Expect 2 colummns, longitude and latitude")
249208

250209
# position map
251210
pos = mapper.create(spin=self.spin)
@@ -259,7 +218,7 @@ async def __call__(
259218
lon, lat = page.get(*col)
260219
w = np.ones(page.size)
261220

262-
mapper.map_values(lon, lat, pos, w)
221+
self.mapper.map_values(lon, lat, pos, w)
263222

264223
ngal += page.size
265224

@@ -307,11 +266,9 @@ async def __call__(
307266
return pos
308267

309268

310-
class ScalarField(Field, spin=0):
269+
class ScalarField(Field):
311270
"""Field of real scalar values in a catalogue."""
312271

313-
uses = "longitude", "latitude", "value", "[weight]"
314-
315272
async def __call__(
316273
self,
317274
catalog: Catalog,
@@ -321,11 +278,13 @@ async def __call__(
321278
"""Map real values from catalogue to HEALPix map."""
322279

323280
# get mapper
324-
mapper = self.mapper_or_error
281+
mapper = self.mapper
325282

326283
# get the column definition of the catalogue
327-
*col, wcol = self.columns_or_error
328-
284+
col = self.columns
285+
self.CheckColumns(self, "longitude", "latitude", "value")
286+
287+
wcol = self.__weight
329288
# scalar field map
330289
val = mapper.create(spin=self.spin)
331290

@@ -373,16 +332,14 @@ async def __call__(
373332
return val
374333

375334

376-
class ComplexField(Field, spin=0):
335+
class ComplexField(Field):
377336
"""Field of complex values in a catalogue.
378337
379338
The :class:`ComplexField` class has zero spin weight, while
380339
subclasses such as :class:`Spin2Field` have non-zero spin weight.
381340
382341
"""
383342

384-
uses = "longitude", "latitude", "real", "imag", "[weight]"
385-
386343
async def __call__(
387344
self,
388345
catalog: Catalog,
@@ -392,10 +349,14 @@ async def __call__(
392349
"""Map complex values from catalogue to HEALPix map."""
393350

394351
# get mapper
395-
mapper = self.mapper_or_error
352+
mapper = self.mapper
396353

397354
# get the column definition of the catalogue
398-
*col, wcol = self.columns_or_error
355+
col = self.columns
356+
357+
self.CheckColumns(self, "longitude", "latitude", "real", "imag")
358+
359+
wcol = self.weight
399360

400361
# complex map with real and imaginary part
401362
val = mapper.create(2, spin=self.spin)
@@ -443,7 +404,7 @@ async def __call__(
443404
return val
444405

445406

446-
class Visibility(Field, spin=0):
407+
class Visibility(Field):
447408
"""Copy visibility map from catalogue at given resolution."""
448409

449410
async def __call__(
@@ -455,7 +416,7 @@ async def __call__(
455416
"""Create a visibility map from the given catalogue."""
456417

457418
# get mapper
458-
mapper = self.mapper_or_error
419+
mapper = self.mapper
459420

460421
# make sure that catalogue has a visibility
461422
visibility = catalog.visibility
@@ -479,11 +440,9 @@ async def __call__(
479440
return out
480441

481442

482-
class Weights(Field, spin=0):
443+
class Weights(Field):
483444
"""Field of weight values from a catalogue."""
484445

485-
uses = "longitude", "latitude", "[weight]"
486-
487446
async def __call__(
488447
self,
489448
catalog: Catalog,
@@ -493,10 +452,12 @@ async def __call__(
493452
"""Map catalogue weights."""
494453

495454
# get mapper
496-
mapper = self.mapper_or_error
455+
mapper = self.mapper
497456

498457
# get the columns for this field
499-
*col, wcol = self.columns_or_error
458+
col = self.columns
459+
self.CheckColumns(self, "longitude", "latitude")
460+
wcol = self.weight
500461

501462
# weight map
502463
wht = mapper.create(spin=self.spin)
@@ -543,8 +504,18 @@ async def __call__(
543504
return wht
544505

545506

546-
class Spin2Field(ComplexField, spin=2):
507+
class Spin2Field(ComplexField):
547508
"""Spin-2 complex field."""
509+
def __init__(
510+
self,
511+
mapper: Mapper | None,
512+
*columns: str,
513+
weight: str | None,
514+
mask: str | None = None,
515+
) -> None:
516+
"""Initialise the field."""
517+
super().__init__(mapper, *columns,weight=weight, mask=mask)
518+
self.__spin=2
548519

549520

550521
Shears = Spin2Field

0 commit comments

Comments
 (0)