Skip to content

Commit

Permalink
Merge pull request #364 from gdsfactory/more_models
Browse files Browse the repository at this point in the history
add more sax models
  • Loading branch information
joamatab authored Mar 24, 2024
2 parents 3a3ee64 + 372d744 commit 45f2250
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 87 deletions.
9 changes: 3 additions & 6 deletions gplugins/klayout/dataprep/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, gdspath, cell_name: str | None = None) -> None:
self.layout = lib.cell_by_name(cell_name) if cell_name else lib.top_cell()
self.lib = lib
self.regions = {}
self.cell = lib[lib.top_cell().cell_index()]

def __getitem__(self, layer: tuple[int, int]) -> Region:
_assert_is_layer(layer)
Expand Down Expand Up @@ -142,13 +143,9 @@ def write_gds(
else:
c.write(gdspath)

def plot(self, **kwargs):
def plot(self) -> kf.KCell:
"""Plot regions."""
gdspath = GDSDIR_TEMP / "out.gds"
self.write_gds(gdspath=gdspath, **kwargs)
gf.clear_cache()
c = gf.import_gds(gdspath)
return c.plot()
return self.cell

def get_kcell(
self, keep_original: bool = True, cellname: str = "Unnamed"
Expand Down
135 changes: 121 additions & 14 deletions gplugins/sax/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from __future__ import annotations

from functools import cache
import inspect
from collections.abc import Callable, Iterable
from functools import cache, partial
from inspect import getmembers

import jax
import jax.numpy as jnp
import sax
from numpy.typing import NDArray
from sax import SDict
from sax.utils import reciprocal

nm = 1e-3

FloatArray = NDArray[jnp.floating]
Float = float | FloatArray

################
# PassThrus
################
Expand Down Expand Up @@ -170,6 +177,7 @@ def grating_coupler(
https://github.com/flaport/photontorch/blob/master/photontorch/components/gratingcouplers.py
Args:
wl: wavelength.
wl0: center wavelength.
loss: in dB.
reflection: from waveguide side.
Expand Down Expand Up @@ -295,7 +303,68 @@ def coupler_single_wavelength(*, coupling: float = 0.5) -> SDict:
)


def mmi1x2() -> SDict:
################
# MMIs
################


def _mmi_amp(
wl: Float = 1.55, wl0: Float = 1.55, fwhm: Float = 0.2, loss_dB: Float = 0.3
):
max_power = 10 ** (-abs(loss_dB) / 10)
f = 1 / wl
f0 = 1 / wl0
f1 = 1 / (wl0 + fwhm / 2)
f2 = 1 / (wl0 - fwhm / 2)
_fwhm = f2 - f1

sigma = _fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
power = jnp.exp(-((f - f0) ** 2) / (2 * sigma**2))
power = max_power * power / power.max() / 2
return jnp.sqrt(power)


def mmi1x2(
wl: Float = 1.55, wl0: Float = 1.55, fwhm: Float = 0.2, loss_dB: Float = 0.3
) -> sax.SDict:
thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
return sax.reciprocal(
{
("o1", "o2"): thru,
("o1", "o3"): thru,
}
)


def mmi2x2(
wl: Float = 1.55,
wl0: Float = 1.55,
fwhm: Float = 0.2,
loss_dB: Float = 0.3,
shift: Float = 0.005,
) -> sax.SDict:
"""Returns 2x2 MMI model.
Args:
wl: wavelength.
wl0: center wavelength.
fwhm: full width half maximum.
loss_dB: loss in dB.
shift: wavelength shift.
"""
thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
cross = 1j * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB)
return sax.reciprocal(
{
("o1", "o3"): thru,
("o1", "o4"): cross,
("o2", "o3"): cross,
("o2", "o4"): thru,
}
)


def mmi1x2_ideal() -> SDict:
"""Returns an ideal 1x2 splitter."""
return reciprocal(
{
Expand All @@ -305,7 +374,7 @@ def mmi1x2() -> SDict:
)


def mmi2x2(*, coupling: float = 0.5) -> SDict:
def mmi2x2_ideal(*, coupling: float = 0.5) -> SDict:
"""Returns an ideal 2x2 splitter.
Args:
Expand All @@ -323,21 +392,59 @@ def mmi2x2(*, coupling: float = 0.5) -> SDict:
)


models = dict(
straight=straight,
bend_euler=bend,
mmi1x2=mmi1x2,
mmi2x2=mmi2x2,
attenuator=attenuator,
taper=straight,
phase_shifter=phase_shifter,
grating_coupler=grating_coupler,
coupler=coupler,
)
################
# Crossings
################


@jax.jit
def crossing(wl: Float = 1.5) -> sax.SDict:
one = jnp.ones_like(jnp.asarray(wl))
return sax.reciprocal(
{
("o1", "o3"): one,
("o2", "o4"): one,
}
)


################
# Models Dict
################
def get_models(modules) -> dict[str, Callable[..., sax.SDict]]:
"""Returns all models in a module or list of modules."""
models = {}
modules = modules if isinstance(modules, Iterable) else [modules]

for module in modules:
for t in getmembers(module):
name = t[0]
func = t[1]
if not callable(func):
continue
_func = func
while isinstance(_func, partial):
_func = _func.func
try:
sig = inspect.signature(_func)
except ValueError:
continue
if str(sig.return_annotation) in {
"sax.SDict",
"SDict",
} and not name.startswith("_"):
models[name] = func
return models


if __name__ == "__main__":
import sys

import gplugins.sax as gs

models = get_models(sys.modules[__name__])
for i in models.keys():
print(i)

gs.plot_model(grating_coupler)
# gs.plot_model(coupler)
21 changes: 16 additions & 5 deletions notebooks/klayout_dataprep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
"import gdsfactory as gf\n",
"from gdsfactory.generic_tech.layer_map import LAYER\n",
"\n",
"import gplugins.klayout.dataprep as dp\n",
"\n",
"gf.CONF.display_type = \"klayout\""
"import gplugins.klayout.dataprep as dp"
]
},
{
Expand Down Expand Up @@ -74,6 +72,7 @@
"d[LAYER.N] = d[\n",
" LAYER.WG\n",
"].copy() # make sure you add the copy to create a copy of the layer\n",
"d.show()\n",
"d.plot()"
]
},
Expand All @@ -93,6 +92,7 @@
"outputs": [],
"source": [
"d[LAYER.N].clear()\n",
"d.show()\n",
"d.plot()"
]
},
Expand All @@ -114,6 +114,7 @@
"outputs": [],
"source": [
"d[LAYER.SLAB90] = d[LAYER.WG] + 2 # size layer by 4 um\n",
"d.show()\n",
"d.plot()"
]
},
Expand All @@ -134,7 +135,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"d[LAYER.SLAB90] += 2 # size layer by 4 um\n",
"d[LAYER.SLAB90] -= 2 # size layer by 2 um\n",
"d.plot()"
Expand Down Expand Up @@ -235,7 +235,6 @@
"\n",
"gdspath = \"mzi_fill.gds\"\n",
"c.write(gdspath)\n",
"c = gf.import_gds(gdspath)\n",
"c.plot()"
]
}
Expand All @@ -249,6 +248,18 @@
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
74 changes: 12 additions & 62 deletions notebooks/klayout_drc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"layer = LAYER.WG\n",
"\n",
"\n",
Expand Down Expand Up @@ -168,67 +167,6 @@
"c.show() # show in klayout\n",
"c.plot()"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"# Klayout connectivity checks\n",
"\n",
"You can you can to check for component overlap and unconnected pins using klayout DRC.\n",
"\n",
"\n",
"The easiest way is to write all the pins on the same layer and define the allowed pin widths.\n",
"This will check for disconnected pins or ports with width mismatch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"\n",
"import gplugins.klayout.drc.write_connectivity as wc\n",
"\n",
"nm = 1e-3\n",
"\n",
"rules = [\n",
" wc.write_connectivity_checks(pin_widths=[0.5, 0.9, 0.45], pin_layer=LAYER.PORT)\n",
"]\n",
"script = wc.write_drc_deck_macro(rules=rules, layers=None)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"You can also define the connectivity checks per section"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"connectivity_checks = [\n",
" wc.ConnectivyCheck(cross_section=\"xs_sc\", pin_length=1 * nm, pin_layer=(1, 10)),\n",
" wc.ConnectivyCheck(\n",
" cross_section=\"xs_sc_auto_widen\", pin_length=1 * nm, pin_layer=(1, 10)\n",
" ),\n",
"]\n",
"rules = [\n",
" wc.write_connectivity_checks_per_section(connectivity_checks=connectivity_checks),\n",
"]\n",
"script = wc.write_drc_deck_macro(rules=rules, layers=None)"
]
}
],
"metadata": {
Expand All @@ -240,6 +178,18 @@
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 45f2250

Please sign in to comment.