diff --git a/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationApiMapper.java b/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationApiMapper.java index 53be488..2603ee2 100644 --- a/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationApiMapper.java +++ b/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationApiMapper.java @@ -17,7 +17,7 @@ public interface ConversationApiMapper { @Mapping(target = "title", source = "title", defaultValue = "Untitled") @Mapping(target = "discussions", source = "discussions") - ConversationResponseCompact mapList(Conversation conversation); + ConversationResponseCompact mapCompact(Conversation conversation); default List defaultDiscussions() { return new ArrayList<>(); diff --git a/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationController.java b/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationController.java index 9163e78..6a7da75 100644 --- a/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationController.java +++ b/backend/src/main/java/com/llm_service/llm_service/controller/conversation/ConversationController.java @@ -2,6 +2,7 @@ import com.llm_service.llm_service.dto.Conversation; import com.llm_service.llm_service.dto.Discussion; +import com.llm_service.llm_service.exception.UnAuthorizedException; import com.llm_service.llm_service.exception.conversation.ConversationNotFoundException; import com.llm_service.llm_service.service.ConversationService; import io.swagger.v3.oas.annotations.Operation; @@ -34,9 +35,9 @@ public class ConversationController { }) @Operation(summary = "Get all conversations") @GetMapping - public ResponseEntity> getAllConversations() { + public ResponseEntity> getAllConversations() throws UnAuthorizedException { return ResponseEntity.ok(conversationService.getAll().stream() - .map(conversationApiMapper::mapList) + .map(conversationApiMapper::mapCompact) .toList()); } @@ -50,7 +51,7 @@ public ResponseEntity> getAllConversations() { @Operation(summary = "Get conversation by ID") @GetMapping("/{id}") public ResponseEntity getConversationById(@PathVariable UUID id) - throws ConversationNotFoundException { + throws ConversationNotFoundException, UnAuthorizedException { Conversation conversation = conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id)); @@ -66,7 +67,7 @@ public ResponseEntity getConversationById(@PathVariable UU }) @Operation(summary = "Create new conversation") @PostMapping - public ResponseEntity createConversation() { + public ResponseEntity createConversation() throws Exception { Conversation conversation = conversationService.start(); return ResponseEntity.status(HttpStatus.OK).body(conversationApiMapper.map(conversation)); @@ -83,7 +84,7 @@ public ResponseEntity createConversation() { @PutMapping("/{id}/continue") public ResponseEntity> continueConversation( @PathVariable UUID id, @RequestBody ConversationRequest conversationRequest) - throws ConversationNotFoundException { + throws ConversationNotFoundException, UnAuthorizedException { Conversation conversation = conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id)); List discussions = conversationService.askLlmQuestion(conversation, conversationRequest.getText()); @@ -100,14 +101,13 @@ public ResponseEntity> continueConversation( }) @Operation(summary = "update conversation title") @PutMapping("/{id}") - public ResponseEntity editConversation( - @PathVariable UUID id, @RequestBody ConversationTitleRequest conversationTitleRequest) - throws ConversationNotFoundException { + public ResponseEntity editConversation( + @PathVariable UUID id, @RequestBody ConversationTitleRequest conversationTitleRequest) throws Exception { Conversation conversation = conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id)); conversationService.editTitle(conversation, conversationTitleRequest.getTitle()); - return ResponseEntity.status(HttpStatus.OK).body(conversationApiMapper.map(conversation)); + return ResponseEntity.status(HttpStatus.OK).body(conversationApiMapper.mapCompact(conversation)); } @ApiResponses( @@ -119,7 +119,8 @@ public ResponseEntity editConversation( }) @Operation(summary = "deletes a conversation") @DeleteMapping("/{id}") - public ResponseEntity deleteConversation(@PathVariable UUID id) throws ConversationNotFoundException { + public ResponseEntity deleteConversation(@PathVariable UUID id) + throws ConversationNotFoundException, UnAuthorizedException { conversationService.getByID(id).orElseThrow(() -> new ConversationNotFoundException(id)); conversationService.delete(id); return ResponseEntity.status(HttpStatus.OK).body(null); @@ -144,4 +145,9 @@ public ResponseEntity handleConversationNotFoundException( ConversationNotFoundException conversationNotFoundException) { return ResponseEntity.status(HttpStatus.NOT_FOUND).body(conversationNotFoundException.getMessage()); } + + @ExceptionHandler(UnAuthorizedException.class) + public ResponseEntity handleUnAuthorized(UnAuthorizedException unAuthorizedException) { + return ResponseEntity.status(HttpStatus.UNAUTHORIZED).body(unAuthorizedException.getMessage()); + } } diff --git a/backend/src/main/java/com/llm_service/llm_service/exception/UnAuthorizedException.java b/backend/src/main/java/com/llm_service/llm_service/exception/UnAuthorizedException.java new file mode 100644 index 0000000..8fa08b4 --- /dev/null +++ b/backend/src/main/java/com/llm_service/llm_service/exception/UnAuthorizedException.java @@ -0,0 +1,7 @@ +package com.llm_service.llm_service.exception; + +public class UnAuthorizedException extends Exception { + public UnAuthorizedException() { + super("UnAuthorized"); + } +} diff --git a/backend/src/main/java/com/llm_service/llm_service/exception/conversation/ConversationNotFoundException.java b/backend/src/main/java/com/llm_service/llm_service/exception/conversation/ConversationNotFoundException.java index a052d4d..95609c1 100644 --- a/backend/src/main/java/com/llm_service/llm_service/exception/conversation/ConversationNotFoundException.java +++ b/backend/src/main/java/com/llm_service/llm_service/exception/conversation/ConversationNotFoundException.java @@ -6,4 +6,8 @@ public class ConversationNotFoundException extends Exception { public ConversationNotFoundException(UUID id) { super("Conversation with id " + id + " is not found"); } + + public ConversationNotFoundException() { + super("Conversation with is not found"); + } } diff --git a/backend/src/main/java/com/llm_service/llm_service/persistance/entities/ConversationEntity.java b/backend/src/main/java/com/llm_service/llm_service/persistance/entities/ConversationEntity.java index 06c20da..7d5cf5d 100644 --- a/backend/src/main/java/com/llm_service/llm_service/persistance/entities/ConversationEntity.java +++ b/backend/src/main/java/com/llm_service/llm_service/persistance/entities/ConversationEntity.java @@ -15,6 +15,10 @@ public class ConversationEntity extends BaseEntity { @OneToMany(mappedBy = "conversation", orphanRemoval = true) private List discussions; + @ManyToOne + @JoinColumn(name = "user_id") + private UserEntity user; + @Column(name = "title") private String title; diff --git a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationEntityMapper.java b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationEntityMapper.java index 5b27232..e754c0d 100644 --- a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationEntityMapper.java +++ b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationEntityMapper.java @@ -1,13 +1,16 @@ package com.llm_service.llm_service.persistance.repositories.conversation; import com.llm_service.llm_service.dto.Conversation; +import com.llm_service.llm_service.dto.User; import com.llm_service.llm_service.persistance.entities.ConversationEntity; import org.mapstruct.Mapper; +import org.mapstruct.Mapping; @Mapper(componentModel = "spring") public interface ConversationEntityMapper { Conversation map(ConversationEntity conversationEntity); - ConversationEntity map(Conversation conversation); + @Mapping(source = "conversation.id", target = "id") + ConversationEntity map(Conversation conversation, User user); } diff --git a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationJpaPersistenceManager.java b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationJpaPersistenceManager.java index f1916f3..2cd09ec 100644 --- a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationJpaPersistenceManager.java +++ b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationJpaPersistenceManager.java @@ -1,6 +1,7 @@ package com.llm_service.llm_service.persistance.repositories.conversation; import com.llm_service.llm_service.dto.Conversation; +import com.llm_service.llm_service.dto.User; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -14,20 +15,25 @@ public class ConversationJpaPersistenceManager implements ConversationPersistenc private final ConversationEntityMapper conversationEntityMapper; @Override - public List findAll() { + public List findAll(UUID userId) { return conversationRepository.findAll().stream() + .filter(conversationEntity -> conversationEntity.getUser().getId() == userId) .map(conversationEntityMapper::map) .toList(); } @Override - public Optional findById(UUID id) { - return conversationRepository.findById(id).map(conversationEntityMapper::map); + public Optional findById(UUID id, UUID userId) { + return conversationRepository + .findById(id) + .filter(conversationEntity -> conversationEntity.getUser().getId() == userId) + .map(conversationEntityMapper::map); } @Override - public Conversation save(Conversation conversation) { - return conversationEntityMapper.map(conversationRepository.save(conversationEntityMapper.map(conversation))); + public Conversation save(Conversation conversation, User user) { + return conversationEntityMapper.map( + conversationRepository.save(conversationEntityMapper.map(conversation, user))); } @Override diff --git a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationPersistenceManager.java b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationPersistenceManager.java index 89e992d..547d5cb 100644 --- a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationPersistenceManager.java +++ b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/conversation/ConversationPersistenceManager.java @@ -1,16 +1,17 @@ package com.llm_service.llm_service.persistance.repositories.conversation; import com.llm_service.llm_service.dto.Conversation; +import com.llm_service.llm_service.dto.User; import java.util.List; import java.util.Optional; import java.util.UUID; public interface ConversationPersistenceManager { - List findAll(); + List findAll(UUID userId); - Optional findById(UUID id); + Optional findById(UUID id, UUID userId); - Conversation save(Conversation conversation); + Conversation save(Conversation conversation, User user); void delete(UUID id); diff --git a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionJpaPersistenceManager.java b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionJpaPersistenceManager.java index 27fadce..342094a 100644 --- a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionJpaPersistenceManager.java +++ b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionJpaPersistenceManager.java @@ -2,6 +2,7 @@ import com.llm_service.llm_service.dto.Conversation; import com.llm_service.llm_service.dto.Discussion; +import com.llm_service.llm_service.dto.User; import com.llm_service.llm_service.persistance.entities.DiscussionEntity; import com.llm_service.llm_service.persistance.repositories.conversation.ConversationEntityMapper; import java.util.List; @@ -30,11 +31,11 @@ public Optional findById(UUID id) { } @Override - public Discussion save(Discussion discussion, Conversation conversation) { + public Discussion save(Discussion discussion, Conversation conversation, User user) { DiscussionEntity discussionEntity = DiscussionEntity.builder() .text(discussion.getText()) .promptRole(discussion.getPromptRole()) - .conversation(conversationEntityMapper.map(conversation)) + .conversation(conversationEntityMapper.map(conversation, user)) .build(); return discussionEntityMapper.map(discussionRepository.save(discussionEntity)); } diff --git a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionPersistenceManager.java b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionPersistenceManager.java index 500427b..7702aa6 100644 --- a/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionPersistenceManager.java +++ b/backend/src/main/java/com/llm_service/llm_service/persistance/repositories/discussion/DiscussionPersistenceManager.java @@ -2,6 +2,7 @@ import com.llm_service.llm_service.dto.Conversation; import com.llm_service.llm_service.dto.Discussion; +import com.llm_service.llm_service.dto.User; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -11,7 +12,7 @@ public interface DiscussionPersistenceManager { Optional findById(UUID id); - Discussion save(Discussion discussion, Conversation conversation); + Discussion save(Discussion discussion, Conversation conversation, User user); void delete(UUID id); } diff --git a/backend/src/main/java/com/llm_service/llm_service/service/ConversationService.java b/backend/src/main/java/com/llm_service/llm_service/service/ConversationService.java index fd10205..294b9f7 100644 --- a/backend/src/main/java/com/llm_service/llm_service/service/ConversationService.java +++ b/backend/src/main/java/com/llm_service/llm_service/service/ConversationService.java @@ -2,9 +2,12 @@ import com.llm_service.llm_service.dto.Conversation; import com.llm_service.llm_service.dto.Discussion; +import com.llm_service.llm_service.dto.User; +import com.llm_service.llm_service.exception.UnAuthorizedException; import com.llm_service.llm_service.persistance.entities.DiscussionRole; import com.llm_service.llm_service.persistance.repositories.conversation.ConversationPersistenceManager; import com.llm_service.llm_service.persistance.repositories.discussion.DiscussionPersistenceManager; +import com.llm_service.llm_service.service.user.UserContext; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -18,33 +21,60 @@ public class ConversationService { private final LLMService llmService; private final ConversationPersistenceManager conversationPersistenceManager; private final DiscussionPersistenceManager discussionPersistenceManager; + private final UserContext userContext; + + public Conversation start() throws Exception { + Optional user = userContext.getUserFromContext(); + + if (user.isEmpty()) { + // TODO fix it later + throw new Exception(""); + } - public Conversation start() { Conversation conversation = Conversation.builder().discussions(null).build(); - return conversationPersistenceManager.save(conversation); + return conversationPersistenceManager.save(conversation, user.get()); } - public List getAll() { - return conversationPersistenceManager.findAll(); + public List getAll() throws UnAuthorizedException { + Optional user = userContext.getUserFromContext(); + + if (user.isEmpty()) { + throw new UnAuthorizedException(); + } + + return conversationPersistenceManager.findAll(user.get().getId()); } - public Optional getByID(UUID id) { - return conversationPersistenceManager.findById(id); + public Optional getByID(UUID id) throws UnAuthorizedException { + Optional user = userContext.getUserFromContext(); + if (user.isEmpty()) { + throw new UnAuthorizedException(); + } + + return conversationPersistenceManager.findById(id, user.get().getId()); } // TODO optimize this fetching mechanism - public List askLlmQuestion(Conversation conversation, String text) { + public List askLlmQuestion(Conversation conversation, String text) throws UnAuthorizedException { + + Optional user = userContext.getUserFromContext(); + + if (user.isEmpty()) { + throw new UnAuthorizedException(); + } + Discussion discussionFromUserParam = Discussion.builder().promptRole(DiscussionRole.USER).text(text).build(); - Discussion discussionFromUser = discussionPersistenceManager.save(discussionFromUserParam, conversation); + Discussion discussionFromUser = + discussionPersistenceManager.save(discussionFromUserParam, conversation, user.get()); if (conversation.getDiscussions().isEmpty()) { String conversationTitle = discussionFromUser .getText() .substring(0, Math.min(20, discussionFromUser.getText().length())); conversationPersistenceManager.save( - conversation.toBuilder().title(conversationTitle).build()); + conversation.toBuilder().title(conversationTitle).build(), user.get()); } String prompt = initializeModel(); @@ -58,7 +88,7 @@ public List askLlmQuestion(Conversation conversation, String text) { .build(); Discussion discussionFromAssistance = - discussionPersistenceManager.save(discussionFromAssistanceParam, conversation); + discussionPersistenceManager.save(discussionFromAssistanceParam, conversation, user.get()); List newDiscussions = new ArrayList<>(); @@ -76,9 +106,16 @@ public void deleteAll() { conversationPersistenceManager.deleteAll(); } - public void editTitle(Conversation conversation, String title) { + public void editTitle(Conversation conversation, String title) throws Exception { + Optional user = userContext.getUserFromContext(); + + if (user.isEmpty()) { + // TODO fix it later + throw new Exception(""); + } + conversationPersistenceManager.save( - conversation.toBuilder().title(title).build()); + conversation.toBuilder().title(title).build(), user.get()); } private String getPrediction(String text) { diff --git a/backend/src/main/java/com/llm_service/llm_service/service/jwt/filter/JwtAuthenticationFilter.java b/backend/src/main/java/com/llm_service/llm_service/service/jwt/filter/JwtAuthenticationFilter.java index ebfdc97..c0030f2 100644 --- a/backend/src/main/java/com/llm_service/llm_service/service/jwt/filter/JwtAuthenticationFilter.java +++ b/backend/src/main/java/com/llm_service/llm_service/service/jwt/filter/JwtAuthenticationFilter.java @@ -6,6 +6,7 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; +import lombok.RequiredArgsConstructor; import org.springframework.lang.NonNull; import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; @@ -16,16 +17,12 @@ import org.springframework.web.filter.OncePerRequestFilter; @Component +@RequiredArgsConstructor public class JwtAuthenticationFilter extends OncePerRequestFilter { private final JwtService jwtService; private final UserDetailsService userDetailsService; - public JwtAuthenticationFilter(JwtService jwtService, UserDetailsService userDetailsService) { - this.jwtService = jwtService; - this.userDetailsService = userDetailsService; - } - @Override protected void doFilterInternal( @NonNull HttpServletRequest request, diff --git a/backend/src/main/java/com/llm_service/llm_service/service/user/UserContext.java b/backend/src/main/java/com/llm_service/llm_service/service/user/UserContext.java new file mode 100644 index 0000000..f664e68 --- /dev/null +++ b/backend/src/main/java/com/llm_service/llm_service/service/user/UserContext.java @@ -0,0 +1,33 @@ +package com.llm_service.llm_service.service.user; + +import com.llm_service.llm_service.dto.User; +import com.llm_service.llm_service.persistance.repositories.user.UserPersistenceManager; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.stereotype.Service; + +@Service +@RequiredArgsConstructor +public class UserContext { + private final UserPersistenceManager userPersistenceManager; + + public Optional getUserFromContext() { + // Get the authentication object from SecurityContextHolder + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + + if (authentication != null && authentication.getPrincipal() instanceof UserDetails) { + UserDetails userDetails = (UserDetails) authentication.getPrincipal(); + + // Access user information from UserDetails + String username = userDetails.getUsername(); + // You can access other user details as well, like authorities: userDetails.getAuthorities() + + return userPersistenceManager.findByUsername(username); + } + + return Optional.empty(); + } +}