Skip to content

Commit

Permalink
Chore/be user context (#45)
Browse files Browse the repository at this point in the history
* Contextualize the User with Conversation

* Refactor the UnAuthorizedExceptions
  • Loading branch information
Mgrdich authored May 11, 2024
1 parent 2ba4d99 commit 9321804
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public interface ConversationApiMapper {

@Mapping(target = "title", source = "title", defaultValue = "Untitled")
@Mapping(target = "discussions", source = "discussions")
ConversationResponseCompact mapList(Conversation conversation);
ConversationResponseCompact mapCompact(Conversation conversation);

default List<Discussion> defaultDiscussions() {
return new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.Discussion;
import com.llm_service.llm_service.exception.UnAuthorizedException;
import com.llm_service.llm_service.exception.conversation.ConversationNotFoundException;
import com.llm_service.llm_service.service.ConversationService;
import io.swagger.v3.oas.annotations.Operation;
Expand Down Expand Up @@ -34,9 +35,9 @@ public class ConversationController {
})
@Operation(summary = "Get all conversations")
@GetMapping
public ResponseEntity<List<ConversationResponseCompact>> getAllConversations() {
public ResponseEntity<List<ConversationResponseCompact>> getAllConversations() throws UnAuthorizedException {
return ResponseEntity.ok(conversationService.getAll().stream()
.map(conversationApiMapper::mapList)
.map(conversationApiMapper::mapCompact)
.toList());
}

Expand All @@ -50,7 +51,7 @@ public ResponseEntity<List<ConversationResponseCompact>> getAllConversations() {
@Operation(summary = "Get conversation by ID")
@GetMapping("/{id}")
public ResponseEntity<ConversationResponse> getConversationById(@PathVariable UUID id)
throws ConversationNotFoundException {
throws ConversationNotFoundException, UnAuthorizedException {
Conversation conversation =
conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id));

Expand All @@ -66,7 +67,7 @@ public ResponseEntity<ConversationResponse> getConversationById(@PathVariable UU
})
@Operation(summary = "Create new conversation")
@PostMapping
public ResponseEntity<ConversationResponse> createConversation() {
public ResponseEntity<ConversationResponse> createConversation() throws Exception {
Conversation conversation = conversationService.start();

return ResponseEntity.status(HttpStatus.OK).body(conversationApiMapper.map(conversation));
Expand All @@ -83,7 +84,7 @@ public ResponseEntity<ConversationResponse> createConversation() {
@PutMapping("/{id}/continue")
public ResponseEntity<List<DiscussionResponse>> continueConversation(
@PathVariable UUID id, @RequestBody ConversationRequest conversationRequest)
throws ConversationNotFoundException {
throws ConversationNotFoundException, UnAuthorizedException {
Conversation conversation =
conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id));
List<Discussion> discussions = conversationService.askLlmQuestion(conversation, conversationRequest.getText());
Expand All @@ -100,14 +101,13 @@ public ResponseEntity<List<DiscussionResponse>> continueConversation(
})
@Operation(summary = "update conversation title")
@PutMapping("/{id}")
public ResponseEntity<ConversationResponse> editConversation(
@PathVariable UUID id, @RequestBody ConversationTitleRequest conversationTitleRequest)
throws ConversationNotFoundException {
public ResponseEntity<ConversationResponseCompact> editConversation(
@PathVariable UUID id, @RequestBody ConversationTitleRequest conversationTitleRequest) throws Exception {
Conversation conversation =
conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id));
conversationService.editTitle(conversation, conversationTitleRequest.getTitle());

return ResponseEntity.status(HttpStatus.OK).body(conversationApiMapper.map(conversation));
return ResponseEntity.status(HttpStatus.OK).body(conversationApiMapper.mapCompact(conversation));
}

@ApiResponses(
Expand All @@ -119,7 +119,8 @@ public ResponseEntity<ConversationResponse> editConversation(
})
@Operation(summary = "deletes a conversation")
@DeleteMapping("/{id}")
public ResponseEntity<Void> deleteConversation(@PathVariable UUID id) throws ConversationNotFoundException {
public ResponseEntity<Void> deleteConversation(@PathVariable UUID id)
throws ConversationNotFoundException, UnAuthorizedException {
conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id));
conversationService.delete(id);
return ResponseEntity.status(HttpStatus.OK).body(null);
Expand All @@ -144,4 +145,9 @@ public ResponseEntity<String> handleConversationNotFoundException(
ConversationNotFoundException conversationNotFoundException) {
return ResponseEntity.status(HttpStatus.NOT_FOUND).body(conversationNotFoundException.getMessage());
}

@ExceptionHandler(UnAuthorizedException.class)
public ResponseEntity<String> handleUnAuthorized(UnAuthorizedException unAuthorizedException) {
return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body(unAuthorizedException.getMessage());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.llm_service.llm_service.exception;

public class UnAuthorizedException extends Exception {
public UnAuthorizedException() {
super("UnAuthorized");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@ public class ConversationNotFoundException extends Exception {
public ConversationNotFoundException(UUID id) {
super("Conversation with id " + id + " is not found");
}

public ConversationNotFoundException() {
super("Conversation with is not found");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ public class ConversationEntity extends BaseEntity {
@OneToMany(mappedBy = "conversation", orphanRemoval = true)
private List<DiscussionEntity> discussions;

@ManyToOne
@JoinColumn(name = "user_id")
private UserEntity user;

@Column(name = "title")
private String title;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.llm_service.llm_service.persistance.repositories.conversation;

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.User;
import com.llm_service.llm_service.persistance.entities.ConversationEntity;
import org.mapstruct.Mapper;
import org.mapstruct.Mapping;

@Mapper(componentModel = "spring")
public interface ConversationEntityMapper {

Conversation map(ConversationEntity conversationEntity);

ConversationEntity map(Conversation conversation);
@Mapping(source = "conversation.id", target = "id")
ConversationEntity map(Conversation conversation, User user);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.llm_service.llm_service.persistance.repositories.conversation;

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.User;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -14,20 +15,25 @@ public class ConversationJpaPersistenceManager implements ConversationPersistenc
private final ConversationEntityMapper conversationEntityMapper;

@Override
public List<Conversation> findAll() {
public List<Conversation> findAll(UUID userId) {
return conversationRepository.findAll().stream()
.filter(conversationEntity -> conversationEntity.getUser().getId() == userId)
.map(conversationEntityMapper::map)
.toList();
}

@Override
public Optional<Conversation> findById(UUID id) {
return conversationRepository.findById(id).map(conversationEntityMapper::map);
public Optional<Conversation> findById(UUID id, UUID userId) {
return conversationRepository
.findById(id)
.filter(conversationEntity -> conversationEntity.getUser().getId() == userId)
.map(conversationEntityMapper::map);
}

@Override
public Conversation save(Conversation conversation) {
return conversationEntityMapper.map(conversationRepository.save(conversationEntityMapper.map(conversation)));
public Conversation save(Conversation conversation, User user) {
return conversationEntityMapper.map(
conversationRepository.save(conversationEntityMapper.map(conversation, user)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package com.llm_service.llm_service.persistance.repositories.conversation;

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.User;
import java.util.List;
import java.util.Optional;
import java.util.UUID;

public interface ConversationPersistenceManager {
List<Conversation> findAll();
List<Conversation> findAll(UUID userId);

Optional<Conversation> findById(UUID id);
Optional<Conversation> findById(UUID id, UUID userId);

Conversation save(Conversation conversation);
Conversation save(Conversation conversation, User user);

void delete(UUID id);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.Discussion;
import com.llm_service.llm_service.dto.User;
import com.llm_service.llm_service.persistance.entities.DiscussionEntity;
import com.llm_service.llm_service.persistance.repositories.conversation.ConversationEntityMapper;
import java.util.List;
Expand Down Expand Up @@ -30,11 +31,11 @@ public Optional<Discussion> findById(UUID id) {
}

@Override
public Discussion save(Discussion discussion, Conversation conversation) {
public Discussion save(Discussion discussion, Conversation conversation, User user) {
DiscussionEntity discussionEntity = DiscussionEntity.builder()
.text(discussion.getText())
.promptRole(discussion.getPromptRole())
.conversation(conversationEntityMapper.map(conversation))
.conversation(conversationEntityMapper.map(conversation, user))
.build();
return discussionEntityMapper.map(discussionRepository.save(discussionEntity));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.Discussion;
import com.llm_service.llm_service.dto.User;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -11,7 +12,7 @@ public interface DiscussionPersistenceManager {

Optional<Discussion> findById(UUID id);

Discussion save(Discussion discussion, Conversation conversation);
Discussion save(Discussion discussion, Conversation conversation, User user);

void delete(UUID id);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import com.llm_service.llm_service.dto.Conversation;
import com.llm_service.llm_service.dto.Discussion;
import com.llm_service.llm_service.dto.User;
import com.llm_service.llm_service.exception.UnAuthorizedException;
import com.llm_service.llm_service.persistance.entities.DiscussionRole;
import com.llm_service.llm_service.persistance.repositories.conversation.ConversationPersistenceManager;
import com.llm_service.llm_service.persistance.repositories.discussion.DiscussionPersistenceManager;
import com.llm_service.llm_service.service.user.UserContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
Expand All @@ -18,33 +21,60 @@ public class ConversationService {
private final LLMService llmService;
private final ConversationPersistenceManager conversationPersistenceManager;
private final DiscussionPersistenceManager discussionPersistenceManager;
private final UserContext userContext;

public Conversation start() throws Exception {
Optional<User> user = userContext.getUserFromContext();

if (user.isEmpty()) {
// TODO fix it later
throw new Exception("");
}

public Conversation start() {
Conversation conversation = Conversation.builder().discussions(null).build();
return conversationPersistenceManager.save(conversation);
return conversationPersistenceManager.save(conversation, user.get());
}

public List<Conversation> getAll() {
return conversationPersistenceManager.findAll();
public List<Conversation> getAll() throws UnAuthorizedException {
Optional<User> user = userContext.getUserFromContext();

if (user.isEmpty()) {
throw new UnAuthorizedException();
}

return conversationPersistenceManager.findAll(user.get().getId());
}

public Optional<Conversation> getByID(UUID id) {
return conversationPersistenceManager.findById(id);
public Optional<Conversation> getByID(UUID id) throws UnAuthorizedException {
Optional<User> user = userContext.getUserFromContext();
if (user.isEmpty()) {
throw new UnAuthorizedException();
}

return conversationPersistenceManager.findById(id, user.get().getId());
}

// TODO optimize this fetching mechanism
public List<Discussion> askLlmQuestion(Conversation conversation, String text) {
public List<Discussion> askLlmQuestion(Conversation conversation, String text) throws UnAuthorizedException {

Optional<User> user = userContext.getUserFromContext();

if (user.isEmpty()) {
throw new UnAuthorizedException();
}

Discussion discussionFromUserParam =
Discussion.builder().promptRole(DiscussionRole.USER).text(text).build();

Discussion discussionFromUser = discussionPersistenceManager.save(discussionFromUserParam, conversation);
Discussion discussionFromUser =
discussionPersistenceManager.save(discussionFromUserParam, conversation, user.get());

if (conversation.getDiscussions().isEmpty()) {
String conversationTitle = discussionFromUser
.getText()
.substring(0, Math.min(20, discussionFromUser.getText().length()));
conversationPersistenceManager.save(
conversation.toBuilder().title(conversationTitle).build());
conversation.toBuilder().title(conversationTitle).build(), user.get());
}

String prompt = initializeModel();
Expand All @@ -58,7 +88,7 @@ public List<Discussion> askLlmQuestion(Conversation conversation, String text) {
.build();

Discussion discussionFromAssistance =
discussionPersistenceManager.save(discussionFromAssistanceParam, conversation);
discussionPersistenceManager.save(discussionFromAssistanceParam, conversation, user.get());

List<Discussion> newDiscussions = new ArrayList<>();

Expand All @@ -76,9 +106,16 @@ public void deleteAll() {
conversationPersistenceManager.deleteAll();
}

public void editTitle(Conversation conversation, String title) {
public void editTitle(Conversation conversation, String title) throws Exception {
Optional<User> user = userContext.getUserFromContext();

if (user.isEmpty()) {
// TODO fix it later
throw new Exception("");
}

conversationPersistenceManager.save(
conversation.toBuilder().title(title).build());
conversation.toBuilder().title(title).build(), user.get());
}

private String getPrediction(String text) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import lombok.RequiredArgsConstructor;
import org.springframework.lang.NonNull;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
Expand All @@ -16,16 +17,12 @@
import org.springframework.web.filter.OncePerRequestFilter;

@Component
@RequiredArgsConstructor
public class JwtAuthenticationFilter extends OncePerRequestFilter {

private final JwtService jwtService;
private final UserDetailsService userDetailsService;

public JwtAuthenticationFilter(JwtService jwtService, UserDetailsService userDetailsService) {
this.jwtService = jwtService;
this.userDetailsService = userDetailsService;
}

@Override
protected void doFilterInternal(
@NonNull HttpServletRequest request,
Expand Down
Loading

0 comments on commit 9321804

Please sign in to comment.