diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java index 680e13d65..03e52f6da 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/LLamaQuickstart.java @@ -10,38 +10,38 @@ import retrofit2.Retrofit; public class LLamaQuickstart extends Endpoint { - private final Retrofit retrofit = RetrofitClientInstance.getInstance(); - private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); - private final ModelMapper modelMapper = new ModelMapper(); - private String query; - public LLamaQuickstart() { - } - - public LLamaQuickstart(String url, RetryPolicy retryPolicy) { - super(url, retryPolicy); - } - - public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) { - super(url, retryPolicy); - this.query = query; - } - - public String getQuery() { - return query; - } - - public void setQuery(String query) { - this.query = query; - } - - - public Observable chatCompletion(String query, ArkRequest arkRequest) { - LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class); - mapper.setQuery(query); - return chatCompletion(mapper, arkRequest); - } - - private Observable chatCompletion(LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) { - return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart)); - } + private final Retrofit retrofit = RetrofitClientInstance.getInstance(); + private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); + private final ModelMapper modelMapper = new ModelMapper(); + private String query; + + public LLamaQuickstart() {} + + public LLamaQuickstart(String url, RetryPolicy retryPolicy) { + super(url, retryPolicy); + } + + public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) { + super(url, retryPolicy); + this.query = query; + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public Observable chatCompletion(String query, ArkRequest arkRequest) { + LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class); + mapper.setQuery(query); + return chatCompletion(mapper, arkRequest); + } + + private Observable chatCompletion( + LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) { + return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart)); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java index 6343bd8a9..ee47fe298 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/LLamaClient.java @@ -1,10 +1,7 @@ package com.edgechain.lib.llama2; import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; -import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; import com.edgechain.lib.llama2.request.LLamaCompletionRequest; -import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; -import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -27,7 +24,7 @@ public class LLamaClient { private final RestTemplate restTemplate = new RestTemplate(); public EdgeChain> createChatCompletion( - LLamaCompletionRequest request, LLamaQuickstart endpoint) { + LLamaCompletionRequest request, LLamaQuickstart endpoint) { return new EdgeChain<>( Observable.create( emitter -> { @@ -47,8 +44,7 @@ public EdgeChain> createChatCompletion( restTemplate.postForObject(endpoint.getUrl(), entity, String.class); List chatCompletionResponse = - objectMapper.readValue( - response, new TypeReference<>() {}); + objectMapper.readValue(response, new TypeReference<>() {}); emitter.onNext(chatCompletionResponse); emitter.onComplete(); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java index 564777b7d..481a890f8 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java @@ -21,71 +21,69 @@ @Service public class Llama2Client { - @Autowired - private ObjectMapper objectMapper; - private final Logger logger = LoggerFactory.getLogger(getClass()); - private final RestTemplate restTemplate = new RestTemplate(); - - public EdgeChain> createChatCompletion( - Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - - logger.info("Logging ChatCompletion...."); - - logger.info("==============REQUEST DATA================"); - logger.info(request.toString()); - - // Create headers - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - HttpEntity entity = new HttpEntity<>(request, headers); - // - String response = - restTemplate.postForObject(endpoint.getUrl(), entity, String.class); - - List chatCompletionResponse = - objectMapper.readValue( - response, new TypeReference<>() { - }); - emitter.onNext(chatCompletionResponse); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } - - public EdgeChain createGetChatCompletion(LLamaQuickstart endpoint) { - return new EdgeChain<>( - Observable.create( - emitter -> { - try { - - // Create headers - HttpHeaders headers = new HttpHeaders(); - headers.set("User-Agent", "insomnia/8.2.0"); - HttpEntity entity = new HttpEntity<>(headers); - - Map param = Collections.singletonMap("query", endpoint.getQuery()); - - String endpointUrl = endpoint.getUrl() + "?query={query}"; - - ResponseEntity response = restTemplate.exchange(endpointUrl, HttpMethod.GET, entity, String.class, param); - - logger.info("\nRESPONSE DATA {}\n", response.getBody()); - - emitter.onNext(response.getBody()); - emitter.onComplete(); - - } catch (final Exception e) { - emitter.onError(e); - } - }), - endpoint); - } + @Autowired private ObjectMapper objectMapper; + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final RestTemplate restTemplate = new RestTemplate(); + + public EdgeChain> createChatCompletion( + Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + logger.info("Logging ChatCompletion...."); + + logger.info("==============REQUEST DATA================"); + logger.info(request.toString()); + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity entity = new HttpEntity<>(request, headers); + // + String response = + restTemplate.postForObject(endpoint.getUrl(), entity, String.class); + + List chatCompletionResponse = + objectMapper.readValue(response, new TypeReference<>() {}); + emitter.onNext(chatCompletionResponse); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } + + public EdgeChain createGetChatCompletion(LLamaQuickstart endpoint) { + return new EdgeChain<>( + Observable.create( + emitter -> { + try { + + // Create headers + HttpHeaders headers = new HttpHeaders(); + headers.set("User-Agent", "insomnia/8.2.0"); + HttpEntity entity = new HttpEntity<>(headers); + + Map param = Collections.singletonMap("query", endpoint.getQuery()); + + String endpointUrl = endpoint.getUrl() + "?query={query}"; + + ResponseEntity response = + restTemplate.exchange(endpointUrl, HttpMethod.GET, entity, String.class, param); + + logger.info("\nRESPONSE DATA {}\n", response.getBody()); + + emitter.onNext(response.getBody()); + emitter.onComplete(); + + } catch (final Exception e) { + emitter.onError(e); + } + }), + endpoint); + } } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java index 6b9733486..eea4b0d99 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/LLamaCompletionRequest.java @@ -7,10 +7,13 @@ public class LLamaCompletionRequest { @JsonProperty("text_inputs") private String textInputs; + @JsonProperty("return_full_text") private Boolean returnFullText; + @JsonProperty("top_k") private Integer topK; + public LLamaCompletionRequest() {} public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) { @@ -73,7 +76,7 @@ public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFul return this; } - public LlamaSupportChatCompletionRequestBuilder topK(Integer topK){ + public LlamaSupportChatCompletionRequestBuilder topK(Integer topK) { this.topK = topK; return this; } diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java index 42a081c9f..5f18ace6c 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java @@ -3,18 +3,16 @@ import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart; import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; -import com.edgechain.lib.request.ArkRequest; import io.reactivex.rxjava3.core.Single; import retrofit2.http.Body; -import retrofit2.http.GET; import retrofit2.http.POST; -import retrofit2.http.Query; import java.util.List; public interface Llama2Service { @POST(value = "llama/chat-completion") Single> chatCompletion(@Body Llama2Endpoint llama2Endpoint); + @POST(value = "llama/chat-completion") Single llamaCompletion(@Body LLamaQuickstart lLamaQuickstart); -} \ No newline at end of file +} diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java index 25bbb520f..d998c28bf 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/service/controllers/llama2/Llama2Controller.java @@ -24,8 +24,7 @@ public class Llama2Controller { @PostMapping(value = "/chat-completion") public Single getChatCompletion(@RequestBody LLamaQuickstart endpoint) { - EdgeChain edgeChain = - llama2Client.createGetChatCompletion(endpoint); + EdgeChain edgeChain = llama2Client.createGetChatCompletion(endpoint); return edgeChain.toSingle(); } }