
Bug (jasp): Error when a function is passed as an argument for a RUS trial_function #114

Open
renezander90 opened this issue Dec 30, 2024 · 4 comments

@renezander90
Contributor

import numpy as np
from qrisp import *
from qrisp.jasp import make_jaspr

def some_function(qv):
    h(qv)

def test(fun):
    
    qf = QuantumFloat(1)
    fun(qf)

    qb = QuantumFloat(1)
    h(qb)

    return measure(qb)==1, qf

def test_fun(i):

    qv = RUS(test)(some_function)

    return qv

jaspr = make_jaspr(test_fun)(1)
print(jaspr)

Yields:

TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute
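For context, this message appears to come from JAX. When a function is traced, every non-static argument is converted to an abstract array, and a plain Python function has no dtype. The sketch below reproduces the same class of error with plain jax.jit and shows how marking the callable as static avoids it. It is an analogy, not the RUS implementation. The names apply_fun and double are hypothetical stand-ins for test and some_function, and only JAX itself is assumed.

```python
import jax
import jax.numpy as jnp


def double(x):
    # Hypothetical stand-in for some_function: a plain Python callable.
    return 2 * x


def apply_fun(fun, x):
    # Hypothetical stand-in for the trial function: it receives a callable.
    return fun(x)


# Without static arguments, jax.jit tries to turn `double` into an abstract
# array and raises a TypeError (the exact wording depends on the JAX version).
try:
    jax.jit(apply_fun)(double, jnp.array(3.0))
    dynamic_failed = False
except TypeError:
    dynamic_failed = True

# Marking argument 0 as static treats the callable as a compile-time constant
# (static arguments only need to be hashable, which Python functions are).
static_result = jax.jit(apply_fun, static_argnums=0)(double, jnp.array(3.0))

print(dynamic_failed)          # True
print(float(static_result))    # 6.0
```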

@renezander90
Contributor Author

There is also a bug when an array is passed as an argument:

import numpy as np
import jax.numpy as jnp
from qrisp import *
from qrisp.jasp import make_jaspr

def test(n):
    
    qf = QuantumFloat(1)

    qb = QuantumFloat(1)
    h(qb)

    return measure(qb)==1, qf

#@terminal_sampling
@jaspify
def main():
     
    A = jnp.array([[1, 0],
                   [1, 0]])

    qv = RUS(test)(A)

    return measure(qv)

#jaspr = make_jaspr(main)()
#print(jaspr)

main()

Yields:

Exception: Tried to evaluate jaxpr with insufficient arguments

In this case, make_jaspr(main)() does not raise an error; the exception only occurs when main is run through the jaspify decorator.
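One way to read "insufficient arguments": when a function is traced, each dynamic (non-static) argument becomes an input variable of the resulting jaxpr, even if the body never uses it, so whatever evaluates that jaxpr must supply a value for it. A static argument, in contrast, is baked into the program. The plain-JAX sketch below uses jax.make_jaxpr to show this general mechanism; the functions ignores_array and scale are hypothetical, and whether this is exactly what goes wrong inside RUS/jaspify is an assumption not confirmed in this thread.

```python
import jax
import jax.numpy as jnp


def ignores_array(A):
    # Like test(n) above, the argument is never used in the body.
    return jnp.zeros(())


A = jnp.array([[1, 0],
               [1, 0]])

closed = jax.make_jaxpr(ignores_array)(A)
# The unused array is still an input of the traced program.
print(len(closed.jaxpr.invars))  # 1


def scale(k, x):
    # k is only used as a multiplier.
    return k * x


# Both arguments dynamic: both become jaxpr inputs.
closed_dynamic = jax.make_jaxpr(scale)(3, jnp.array(2.0))
# k static: it is baked into the jaxpr, so only x remains an input.
closed_static = jax.make_jaxpr(scale, static_argnums=0)(3, jnp.array(2.0))

print(len(closed_dynamic.jaxpr.invars))  # 2
print(len(closed_static.jaxpr.invars))   # 1
```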

@positr0nium
Contributor

Fixed in 3e26aa9.
It is now possible to specify static arguments. Examples of how to use this feature can be found in the documentation:

**Static arguments**
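A hedged aside that may matter for the array example above: in JAX, static arguments must be hashable, because they become part of the compilation cache key, and a jax.numpy array is not hashable. If static arguments in RUS follow the same rule as jax.jit (an assumption; the RUS API itself is not shown in this thread), an array would either remain a dynamic argument or would need to be passed in a hashable form such as a tuple of tuples. The plain-JAX sketch below uses a hypothetical function weighted_sum.

```python
import jax
import jax.numpy as jnp


def weighted_sum(table, x):
    # `table` is static: Python can iterate over it and read concrete numbers.
    total = 0.0
    for row in table:
        for entry in row:
            total = total + entry * x
    return total


jitted = jax.jit(weighted_sum, static_argnums=0)

# A tuple of tuples is hashable, so it can be a static argument.
from_tuple = jitted(((1, 0), (1, 0)), jnp.array(3.0))

# A jax.numpy array is not hashable, so passing it as a static argument fails.
try:
    jitted(jnp.array([[1, 0], [1, 0]]), jnp.array(3.0))
    array_static_failed = False
except (ValueError, TypeError):
    array_static_failed = True

print(float(from_tuple))      # 6.0
print(array_static_failed)    # True
```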

@renezander90
Contributor Author

Another issue:

from qrisp.jasp import *
from qrisp import *
import jax.numpy as jnp

@RUS
def block_encoding(n):

    qb = QuantumFloat(1)

    case_indicator = QuantumFloat(n)
    case_indicator_qubits = [case_indicator[i] for i in jrange(n)]
    
    for i in jrange(n):
        with control(case_indicator_qubits[:i+1], ctrl_state = 2**i-1):
            x(qb)
    
    return (measure(qb) == 0), case_indicator

@terminal_sampling
def main():
    
    qf = block_encoding(4)
    
    return qf
        
res_dict = main()
print(res_dict)

Yields:

TracerIntegerConversionError: The index() method was called on traced array with shape int32[].

(There is no error without the control environment.)

@positr0nium
Contributor

positr0nium commented Jan 1, 2025

This might be a misunderstanding rather than a bug. The control qubits of the ControlEnvironment have to be a static list. This is why we convert the QuantumVariable, which is a dynamic quantity, into a list. If we know its size is static, we can do this with a simple static loop. Here, however, n is a dynamic number, so this conversion is not possible. The ctrl_state also has to be a static integer.
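The same restriction shows up in plain JAX, independent of Qrisp: building a Python list (or slicing one, or looping with range) needs a concrete integer, so a traced n raises the TracerIntegerConversionError from the report, while a static n works. This is a minimal sketch in which jax.jit stands in for the tracing done by RUS; control_states is a hypothetical function that mirrors the ctrl_state = 2**i-1 values above.

```python
import jax
import jax.numpy as jnp
from jax.errors import TracerIntegerConversionError


def control_states(n):
    # range(n) calls n.__index__(), which requires a concrete Python int.
    return jnp.array([2**i - 1 for i in range(n)])


# n is traced (dynamic): the list cannot be built at trace time.
try:
    jax.jit(control_states)(4)
    dynamic_failed = False
except TracerIntegerConversionError:
    dynamic_failed = True

# n is static: the loop runs in Python during tracing and yields a fixed list.
static_states = jax.jit(control_states, static_argnums=0)(4)

print(dynamic_failed)             # True
print(static_states.tolist())     # [0, 1, 3, 7]
```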
