diff --git a/alphafold3_pytorch/alphafold3.py b/alphafold3_pytorch/alphafold3.py index ff69e41f..f486eeec 100644 --- a/alphafold3_pytorch/alphafold3.py +++ b/alphafold3_pytorch/alphafold3.py @@ -5740,23 +5740,41 @@ def _inhouse_compute_unresolved_rasa( # first constitute the fibonacci sphere - num_surface_dots = fibonacci_sphere_n * 2. + 1 + num_surface_dots = fibonacci_sphere_n * 2 + 1 golden_ratio = 1. + math.sqrt(5.) / 2 + weight = (4. * math.pi) / num_surface_dots arange = torch.arange(-fibonacci_sphere_n, fibonacci_sphere_n + 1) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3] lat = torch.asin((2. * arange) / num_surface_dots) lon = torch.fmod(arange, golden_ratio) * 2 * math.pi / golden_ratio - surface_dots = torch.stack(( + # ein: sd - surface dots + + unit_surface_dots: Float['sd 3'] = torch.stack(( lon.sin() * lat.cos(), lon.cos() * lat.cos(), lat.sin() ), dim = -1) - weight = (4. * math.pi) / num_surface_dots + # overall logic + # assume radius of 1 for all atoms for starters - in algorithm, radius is different depending on backbone + sidechain atoms + + radius = 1. + + atom_rel_pos = einx.subtract('i c, j c-> i j c', atom_pos, atom_pos) + + surface_dots = radius * unit_surface_dots + + dist_from_surface_dots_sq = (einx.subtract('i j c, sd c -> i sd j c') ** 2).sum(dim = -1) + + free = reduce(dist_from_surface_dots_sq > radius, 'i sd j -> i sd', 'all') + + score = reduce(free.float() * weight, 'm sd -> m', 'sum') + + per_atom_accessible_surface_score = reduce(score * radius ** 2, 'm sd -> m') - # rest of logic written by @xluo + # rest written by @xluo rasa = [] aatypes = []