|
128 | 128 | "@patch\n",
|
129 | 129 | "def _initialize_carm(self: Detector):\n",
|
130 | 130 | " \"\"\"Initialize the default position for the source and detector plane.\"\"\"\n",
|
131 |
| - " # Initialize the source on the x-axis\n", |
132 |
| - " source = torch.tensor([[self.sdr, 0.0, 0.0]])\n", |
133 |
| - "\n", |
134 |
| - " # Initialize the center of the detector plane on the negative x-axis\n", |
135 |
| - " center = torch.tensor([[-self.sdr, 0.0, 0.0]])\n", |
| 131 | + " try:\n", |
| 132 | + " device = self.sdr.device\n", |
| 133 | + " except AttributeError:\n", |
| 134 | + " device = torch.device(\"cpu\")\n", |
| 135 | + " \n", |
| 136 | + " # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis\n", |
| 137 | + " source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr\n", |
| 138 | + " center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr\n", |
136 | 139 | "\n",
|
137 | 140 | " # Use the standard basis for the detector plane\n",
|
138 |
| - " basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n", |
| 141 | + " basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)\n", |
139 | 142 | "\n",
|
140 | 143 | " # Construct the detector plane with different offsets for even or odd heights\n",
|
141 | 144 | " h_off = 1.0 if self.height % 2 else 0.5\n",
|
142 | 145 | " w_off = 1.0 if self.width % 2 else 0.5\n",
|
143 | 146 | "\n",
|
144 | 147 | " # Construct equally spaced points along the basis vectors\n",
|
145 |
| - " t = (torch.arange(-self.height // 2, self.height // 2) + h_off) * self.delx\n", |
146 |
| - " s = (torch.arange(-self.width // 2, self.width // 2) + w_off) * self.dely\n", |
| 148 | + " t = (torch.arange(-self.height // 2, self.height // 2, device=device) + h_off) * self.delx\n", |
| 149 | + " s = (torch.arange(-self.width // 2, self.width // 2, device=device) + w_off) * self.dely\n", |
147 | 150 | " if self.reverse_x_axis:\n",
|
148 | 151 | " s = -s\n",
|
149 | 152 | " coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
|
|
0 commit comments