Skip to content

Commit

Permalink
Google Java Format
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions committed Oct 15, 2023
1 parent 24cf405 commit 6b893a1
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,38 @@
import retrofit2.Retrofit;

public class LLamaQuickstart extends Endpoint {
private final Retrofit retrofit = RetrofitClientInstance.getInstance();
private final Llama2Service llama2Service = retrofit.create(Llama2Service.class);
private final ModelMapper modelMapper = new ModelMapper();
private String query;
public LLamaQuickstart() {
}

public LLamaQuickstart(String url, RetryPolicy retryPolicy) {
super(url, retryPolicy);
}

public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) {
super(url, retryPolicy);
this.query = query;
}

public String getQuery() {
return query;
}

public void setQuery(String query) {
this.query = query;
}


public Observable<String> chatCompletion(String query, ArkRequest arkRequest) {
LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class);
mapper.setQuery(query);
return chatCompletion(mapper, arkRequest);
}

private Observable<String> chatCompletion(LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) {
return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart));
}
private final Retrofit retrofit = RetrofitClientInstance.getInstance();
private final Llama2Service llama2Service = retrofit.create(Llama2Service.class);
private final ModelMapper modelMapper = new ModelMapper();
private String query;

public LLamaQuickstart() {}

public LLamaQuickstart(String url, RetryPolicy retryPolicy) {
super(url, retryPolicy);
}

public LLamaQuickstart(String url, RetryPolicy retryPolicy, String query) {
super(url, retryPolicy);
this.query = query;
}

public String getQuery() {
return query;
}

public void setQuery(String query) {
this.query = query;
}

public Observable<String> chatCompletion(String query, ArkRequest arkRequest) {
LLamaQuickstart mapper = modelMapper.map(this, LLamaQuickstart.class);
mapper.setQuery(query);
return chatCompletion(mapper, arkRequest);
}

private Observable<String> chatCompletion(
LLamaQuickstart lLamaQuickstart, ArkRequest arkRequest) {
return Observable.fromSingle(this.llama2Service.llamaCompletion(lLamaQuickstart));
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package com.edgechain.lib.llama2;

import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.request.LLamaCompletionRequest;
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -27,7 +24,7 @@ public class LLamaClient {
private final RestTemplate restTemplate = new RestTemplate();

public EdgeChain<List<String>> createChatCompletion(
LLamaCompletionRequest request, LLamaQuickstart endpoint) {
LLamaCompletionRequest request, LLamaQuickstart endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
Expand All @@ -47,8 +44,7 @@ public EdgeChain<List<String>> createChatCompletion(
restTemplate.postForObject(endpoint.getUrl(), entity, String.class);

List<String> chatCompletionResponse =
objectMapper.readValue(
response, new TypeReference<>() {});
objectMapper.readValue(response, new TypeReference<>() {});
emitter.onNext(chatCompletionResponse);
emitter.onComplete();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,71 +21,69 @@

@Service
public class Llama2Client {
@Autowired
private ObjectMapper objectMapper;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final RestTemplate restTemplate = new RestTemplate();

public EdgeChain<List<Llama2ChatCompletionResponse>> createChatCompletion(
Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
try {

logger.info("Logging ChatCompletion....");

logger.info("==============REQUEST DATA================");
logger.info(request.toString());

// Create headers
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Llama2ChatCompletionRequest> entity = new HttpEntity<>(request, headers);
//
String response =
restTemplate.postForObject(endpoint.getUrl(), entity, String.class);

List<Llama2ChatCompletionResponse> chatCompletionResponse =
objectMapper.readValue(
response, new TypeReference<>() {
});
emitter.onNext(chatCompletionResponse);
emitter.onComplete();

} catch (final Exception e) {
emitter.onError(e);
}
}),
endpoint);
}

public EdgeChain<String> createGetChatCompletion(LLamaQuickstart endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
try {

// Create headers
HttpHeaders headers = new HttpHeaders();
headers.set("User-Agent", "insomnia/8.2.0");
HttpEntity<?> entity = new HttpEntity<>(headers);

Map<String, String> param = Collections.singletonMap("query", endpoint.getQuery());

String endpointUrl = endpoint.getUrl() + "?query={query}";

ResponseEntity<String> response = restTemplate.exchange(endpointUrl, HttpMethod.GET, entity, String.class, param);

logger.info("\nRESPONSE DATA {}\n", response.getBody());

emitter.onNext(response.getBody());
emitter.onComplete();

} catch (final Exception e) {
emitter.onError(e);
}
}),
endpoint);
}
@Autowired private ObjectMapper objectMapper;
private final Logger logger = LoggerFactory.getLogger(getClass());
private final RestTemplate restTemplate = new RestTemplate();

public EdgeChain<List<Llama2ChatCompletionResponse>> createChatCompletion(
Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
try {

logger.info("Logging ChatCompletion....");

logger.info("==============REQUEST DATA================");
logger.info(request.toString());

// Create headers
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Llama2ChatCompletionRequest> entity = new HttpEntity<>(request, headers);
//
String response =
restTemplate.postForObject(endpoint.getUrl(), entity, String.class);

List<Llama2ChatCompletionResponse> chatCompletionResponse =
objectMapper.readValue(response, new TypeReference<>() {});
emitter.onNext(chatCompletionResponse);
emitter.onComplete();

} catch (final Exception e) {
emitter.onError(e);
}
}),
endpoint);
}

public EdgeChain<String> createGetChatCompletion(LLamaQuickstart endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
try {

// Create headers
HttpHeaders headers = new HttpHeaders();
headers.set("User-Agent", "insomnia/8.2.0");
HttpEntity<?> entity = new HttpEntity<>(headers);

Map<String, String> param = Collections.singletonMap("query", endpoint.getQuery());

String endpointUrl = endpoint.getUrl() + "?query={query}";

ResponseEntity<String> response =
restTemplate.exchange(endpointUrl, HttpMethod.GET, entity, String.class, param);

logger.info("\nRESPONSE DATA {}\n", response.getBody());

emitter.onNext(response.getBody());
emitter.onComplete();

} catch (final Exception e) {
emitter.onError(e);
}
}),
endpoint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
public class LLamaCompletionRequest {
@JsonProperty("text_inputs")
private String textInputs;

@JsonProperty("return_full_text")
private Boolean returnFullText;

@JsonProperty("top_k")
private Integer topK;

public LLamaCompletionRequest() {}

public LLamaCompletionRequest(String textInputs, Boolean returnFullText, Integer topK) {
Expand Down Expand Up @@ -73,7 +76,7 @@ public LlamaSupportChatCompletionRequestBuilder returnFullText(Boolean returnFul
return this;
}

public LlamaSupportChatCompletionRequestBuilder topK(Integer topK){
public LlamaSupportChatCompletionRequestBuilder topK(Integer topK) {
this.topK = topK;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
import com.edgechain.lib.endpoint.impl.llm.LLamaQuickstart;
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint;
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse;
import com.edgechain.lib.request.ArkRequest;
import io.reactivex.rxjava3.core.Single;
import retrofit2.http.Body;
import retrofit2.http.GET;
import retrofit2.http.POST;
import retrofit2.http.Query;

import java.util.List;

public interface Llama2Service {
@POST(value = "llama/chat-completion")
Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint);

@POST(value = "llama/chat-completion")
Single<String> llamaCompletion(@Body LLamaQuickstart lLamaQuickstart);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ public class Llama2Controller {
@PostMapping(value = "/chat-completion")
public Single<String> getChatCompletion(@RequestBody LLamaQuickstart endpoint) {

EdgeChain<String> edgeChain =
llama2Client.createGetChatCompletion(endpoint);
EdgeChain<String> edgeChain = llama2Client.createGetChatCompletion(endpoint);
return edgeChain.toSingle();
}
}

0 comments on commit 6b893a1

Please sign in to comment.