diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000..26573c44 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,61 @@ +# Copilot Instructions for ToolNeuron + +## Build & Test + +```bash +# Build debug APK +./gradlew assembleDebug + +# Build release APK (requires signing config in local.properties) +./gradlew assembleRelease + +# Run all unit tests +./gradlew test + +# Run tests for a single module +./gradlew :app:test +./gradlew :memory-vault:test +./gradlew :neuron-packet:test + +# Run a single test class +./gradlew :app:testDebugUnitTest --tests "com.dark.tool_neuron.McpToolMapperTest" +``` + +**Requirements:** JDK 17, Android SDK 36, NDK 26.x. The `neuron-packet` module requires OpenSSL prebuilt libraries — see `neuron-packet/SETUP.md`. + +## Architecture + +This is an Android app (Kotlin + C++) that runs LLMs and Stable Diffusion entirely on-device. It's a multi-module Gradle project: + +- **`app`** — Main application. Jetpack Compose UI, MVVM with Hilt DI, Room database. Package: `com.dark.tool_neuron`. +- **`memory-vault`** — Encrypted binary storage engine with WAL crash recovery, LZ4 compression, full-text and vector indices. Package: `com.memoryvault`. See `docs/MemoryVault.MD` for the storage format spec. +- **`neuron-packet`** — Secure data export/import with AES-256-GCM encryption. Has both Kotlin and C++ (JNI) sides. Package: `com.neuronpacket`. The C++ code lives in `neuron-packet/src/main/cpp/` and builds via CMake. + +### AI inference layer + +Native inference is provided by pre-built AAR libraries in `libs/`: +- `ai_gguf-release.aar` — llama.cpp bindings for text generation (GGUF models) +- `ai_sd-release.aar` — Stable Diffusion 1.5 bindings for image generation + +These are wrapped by engine classes in `app/.../engine/`: +- `GGUFEngine` — loads GGUF models, generates text, supports function calling with tool grammars +- `DiffusionEngine` — loads SD models, generates images +- `EmbeddingEngine` — generates text embeddings for RAG/vector search + +`LLMService` is a bound Android Service that exposes these engines via AIDL IPC. + +### Data flow + +`UI (Compose screens)` → `ViewModel (@HiltViewModel)` → `Repository` → `Room DAO / MemoryVault` + +ViewModels expose `StateFlow` for reactive UI updates. All async work uses Kotlin Coroutines with `viewModelScope`. + +## Key Conventions + +- **DI:** Hilt everywhere. Activities use `@AndroidEntryPoint`, ViewModels use `@HiltViewModel`. All modules are defined in `app/.../di/HiltModules.kt` and installed in `SingletonComponent`. +- **Navigation:** Jetpack Compose NavHost in `MainActivity`. Routes are defined as a `Screen` sealed class. Uses slide + fade transitions. +- **Serialization:** `kotlinx.serialization` for JSON. Room entities live in `models/table_schema/`. +- **NDK targets:** `arm64-v8a` and `x86_64` only. +- **Build config:** Properties are read from `local.properties` or environment variables via `getProperty()` (defined in each module's `build.gradle.kts`). The `ALIAS` property is used for build config fields. +- **UI constants:** Shared sizing/padding values are in `global/Standards.kt`. +- **Version catalog:** All dependency versions are managed in `gradle/libs.versions.toml`. diff --git a/.github/workflows/build-debug-apk.yml b/.github/workflows/build-debug-apk.yml new file mode 100644 index 00000000..e45db48f --- /dev/null +++ b/.github/workflows/build-debug-apk.yml @@ -0,0 +1,35 @@ +name: Build Debug APK + +on: + pull_request: + branches: [ main, master ] + workflow_dispatch: + +jobs: + build: + name: Build Debug APK + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: gradle + + - name: Grant execute permission for gradlew + run: chmod +x gradlew + + - name: Build Debug APK + run: ./gradlew assembleDebug --no-daemon + + - name: Upload Debug APK + uses: actions/upload-artifact@v4 + with: + name: app-debug + path: app/build/outputs/apk/debug/app-debug.apk + retention-days: 14 diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 4fb3af7c..34eb0852 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -152,6 +152,10 @@ dependencies { // Debug debugImplementation(libs.androidx.compose.ui.tooling) + + // Tests + testImplementation(libs.junit) + testImplementation(libs.org.json) } fun getProperty(value: String): String { @@ -163,4 +167,4 @@ fun getProperty(value: String): String { } else { System.getenv(value) ?: "\"sample_val\"" } -} \ No newline at end of file +} diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index 25bf8166..c31205b9 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -5,7 +5,12 @@ + + + + + Screen.Terms.route !termsAccepted -> Screen.Welcome.route - - // Terms accepted but setup not done and no models: go to setup !setupDone && !hasModel -> Screen.OnboardingSetup.route - - // Everything done else -> Screen.Chat.route } } @@ -122,7 +119,6 @@ class MainActivity : ComponentActivity() { override fun onDestroy() { super.onDestroy() - // Clear password cache when app terminates ragRepository.clearPasswordCache() LlmModelWorker.unbindService() AppContainer.shutdown() @@ -130,7 +126,7 @@ class MainActivity : ComponentActivity() { } sealed class Screen(val route: String) { - // Onboarding (flat routes so any can be used as startDestination) + // Onboarding object Welcome : Screen("welcome") object Terms : Screen("terms") object OnboardingSetup : Screen("setup") @@ -141,6 +137,8 @@ sealed class Screen(val route: String) { object Editor : Screen("editor") object Settings : Screen("settings") object VaultManager : Screen("vault_manager") + object McpServers : Screen("mcp_servers") + object McpStore : Screen("mcp_store") object Personas : Screen("personas") object PersonaEditor : Screen("persona_editor/{personaId}") { fun createRoute(personaId: String? = null) = "persona_editor/${personaId ?: "new"}" @@ -157,7 +155,6 @@ fun AppNavigation( val scope = rememberCoroutineScope() val navController = rememberNavController() - // Activity-scoped ViewModels for shared state between Chat and Personas val chatViewModel: ChatViewModel = hiltViewModel() val llmModelViewModel: LLMModelViewModel = hiltViewModel() @@ -207,12 +204,10 @@ fun AppNavigation( termsDataStore.acceptTerms() } if (hasModelsInstalled) { - // Returning user: skip setup, go to chat navController.navigate(Screen.Chat.route) { popUpTo(0) { inclusive = true } } } else { - // New user: proceed to setup navController.navigate(Screen.OnboardingSetup.route) } } @@ -241,6 +236,9 @@ fun AppNavigation( onVaultManagerClick = { navController.navigate(Screen.VaultManager.route) }, + onMcpServersClick = { + navController.navigate(Screen.McpServers.route) + }, onCharacterClick = { navController.navigate(Screen.Personas.route) }, @@ -304,7 +302,6 @@ fun AppNavigation( personaId = personaId, onNavigateBack = { navController.popBackStack() }, onDeleted = { - // If deleted persona was active, clear selection if (personaId != null && chatViewModel.activePersona.value?.id == personaId) { chatViewModel.setActivePersona(null) } @@ -318,5 +315,16 @@ fun AppNavigation( onNavigateBack = { navController.popBackStack() } ) } + + composable(Screen.McpServers.route) { + McpServersScreen( + onBackClick = { navController.popBackStack() }, + onStoreClick = { navController.navigate(Screen.McpStore.route) } + ) + } + + composable(Screen.McpStore.route) { + McpStoreScreen(onBackClick = { navController.popBackStack() }) + } } } diff --git a/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt b/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt index 22caea0e..0f1e7c41 100644 --- a/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt +++ b/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt @@ -8,6 +8,7 @@ import androidx.room.TypeConverters import androidx.room.migration.Migration import androidx.sqlite.db.SupportSQLiteDatabase import com.dark.tool_neuron.database.dao.AiMemoryDao +import com.dark.tool_neuron.database.dao.McpServerDao import com.dark.tool_neuron.database.dao.ModelConfigDao import com.dark.tool_neuron.database.dao.ModelDao import com.dark.tool_neuron.database.dao.PersonaDao @@ -15,14 +16,15 @@ import com.dark.tool_neuron.database.dao.RagDao import com.dark.tool_neuron.models.converters.Converters import com.dark.tool_neuron.models.table_schema.AiMemory import com.dark.tool_neuron.models.table_schema.InstalledRag +import com.dark.tool_neuron.models.table_schema.McpServer import com.dark.tool_neuron.models.table_schema.Model import com.dark.tool_neuron.models.table_schema.ModelConfig import com.dark.tool_neuron.models.table_schema.Persona import java.util.UUID @Database( - entities = [Model::class, ModelConfig::class, InstalledRag::class, Persona::class, AiMemory::class], - version = 6, + entities = [Model::class, ModelConfig::class, InstalledRag::class, McpServer::class, Persona::class, AiMemory::class], + version = 7, exportSchema = false ) @TypeConverters(Converters::class) @@ -30,6 +32,7 @@ abstract class AppDatabase : RoomDatabase() { abstract fun modelDao(): ModelDao abstract fun modelConfigDao(): ModelConfigDao abstract fun ragDao(): RagDao + abstract fun mcpServerDao(): McpServerDao abstract fun personaDao(): PersonaDao abstract fun aiMemoryDao(): AiMemoryDao @@ -67,7 +70,6 @@ abstract class AppDatabase : RoomDatabase() { private val MIGRATION_2_3 = object : Migration(2, 3) { override fun migrate(db: SupportSQLiteDatabase) { - // Add missing columns to installed_rags table db.execSQL("ALTER TABLE installed_rags ADD COLUMN is_encrypted INTEGER NOT NULL DEFAULT 0") db.execSQL("ALTER TABLE installed_rags ADD COLUMN loading_mode INTEGER NOT NULL DEFAULT 1") db.execSQL("ALTER TABLE installed_rags ADD COLUMN has_admin_access INTEGER NOT NULL DEFAULT 0") @@ -76,10 +78,6 @@ abstract class AppDatabase : RoomDatabase() { private val MIGRATION_3_4 = object : Migration(3, 4) { override fun migrate(db: SupportSQLiteDatabase) { - // Recreate installed_rags table without DEFAULT constraints in SQL - // Room expects defaults to be handled at the application level, not database level - - // Create new table with correct schema (no DEFAULT clauses) db.execSQL(""" CREATE TABLE IF NOT EXISTS installed_rags_new ( id TEXT PRIMARY KEY NOT NULL, @@ -107,16 +105,12 @@ abstract class AppDatabase : RoomDatabase() { ) """.trimIndent()) - // Copy data from old table to new table db.execSQL(""" INSERT INTO installed_rags_new SELECT * FROM installed_rags """.trimIndent()) - // Drop old table db.execSQL("DROP TABLE installed_rags") - - // Rename new table to original name db.execSQL("ALTER TABLE installed_rags_new RENAME TO installed_rags") } } @@ -151,10 +145,7 @@ abstract class AppDatabase : RoomDatabase() { ) """.trimIndent()) - // Index on ai_memories.category db.execSQL("CREATE INDEX IF NOT EXISTS index_ai_memories_category ON ai_memories (category)") - - // Seed default personas (v5 schema — no character-card columns yet) seedDefaultPersonasV5(db) } } @@ -170,15 +161,34 @@ abstract class AppDatabase : RoomDatabase() { db.execSQL("ALTER TABLE personas ADD COLUMN tags TEXT NOT NULL DEFAULT '[]'") db.execSQL("ALTER TABLE personas ADD COLUMN avatar_uri TEXT") db.execSQL("ALTER TABLE personas ADD COLUMN creator_notes TEXT NOT NULL DEFAULT ''") - // Migrate legacy systemPrompt into description db.execSQL("UPDATE personas SET description = system_prompt WHERE system_prompt != '' AND description = ''") } } - /** - * v5 schema seed — only the 7 original columns. Used by MIGRATION_4_5 - * where the v6 character-card columns don't exist yet. - */ + private val MIGRATION_6_7 = object : Migration(6, 7) { + override fun migrate(db: SupportSQLiteDatabase) { + // Create mcp_servers table + db.execSQL(""" + CREATE TABLE IF NOT EXISTS mcp_servers ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + url TEXT NOT NULL, + transportType TEXT NOT NULL, + apiKey TEXT, + isEnabled INTEGER NOT NULL, + lastError TEXT, + createdAt INTEGER NOT NULL, + updatedAt INTEGER NOT NULL, + lastConnectedAt INTEGER, + description TEXT NOT NULL, + customHeadersJson TEXT, + isLocal INTEGER NOT NULL DEFAULT 0, + sourceStoreId TEXT + ) + """.trimIndent()) + } + } + private fun seedDefaultPersonasV5(db: SupportSQLiteDatabase) { val now = System.currentTimeMillis() val cols = "id, name, avatar, system_prompt, greeting, is_default, created_at" @@ -213,10 +223,6 @@ abstract class AppDatabase : RoomDatabase() { ) } - /** - * Full v6 schema seed — includes character-card columns. Used by onCreate - * where Room creates the complete schema (all NOT NULL columns present). - */ private fun seedDefaultPersonas(db: SupportSQLiteDatabase) { val now = System.currentTimeMillis() val cols = "id, name, avatar, system_prompt, greeting, is_default, created_at, description, personality, scenario, example_messages, alternate_greetings, tags, creator_notes" @@ -259,7 +265,7 @@ abstract class AppDatabase : RoomDatabase() { AppDatabase::class.java, "llm_models_database" ) - .addMigrations(MIGRATION_1_2, MIGRATION_2_3, MIGRATION_3_4, MIGRATION_4_5, MIGRATION_5_6) + .addMigrations(MIGRATION_1_2, MIGRATION_2_3, MIGRATION_3_4, MIGRATION_4_5, MIGRATION_5_6, MIGRATION_6_7) .addCallback(object : Callback() { override fun onCreate(db: SupportSQLiteDatabase) { super.onCreate(db) @@ -273,4 +279,4 @@ abstract class AppDatabase : RoomDatabase() { } } } -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/database/dao/McpServerDao.kt b/app/src/main/java/com/dark/tool_neuron/database/dao/McpServerDao.kt new file mode 100644 index 00000000..e33e7079 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/database/dao/McpServerDao.kt @@ -0,0 +1,45 @@ +package com.dark.tool_neuron.database.dao + +import androidx.room.* +import com.dark.tool_neuron.models.table_schema.McpServer +import kotlinx.coroutines.flow.Flow + +@Dao +interface McpServerDao { + + @Query("SELECT * FROM mcp_servers ORDER BY name ASC") + fun getAllServers(): Flow> + + @Query("SELECT * FROM mcp_servers WHERE isEnabled = 1 ORDER BY name ASC") + fun getEnabledServers(): Flow> + + @Query("SELECT * FROM mcp_servers WHERE id = :id") + suspend fun getServerById(id: String): McpServer? + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertServer(server: McpServer) + + @Update + suspend fun updateServer(server: McpServer) + + @Delete + suspend fun deleteServer(server: McpServer) + + @Query("DELETE FROM mcp_servers WHERE id = :id") + suspend fun deleteServerById(id: String) + + @Query("UPDATE mcp_servers SET isEnabled = :isEnabled, updatedAt = :updatedAt WHERE id = :id") + suspend fun updateServerEnabled(id: String, isEnabled: Boolean, updatedAt: Long) + + @Query("UPDATE mcp_servers SET lastConnectedAt = :timestamp, updatedAt = :updatedAt WHERE id = :id") + suspend fun updateLastConnected(id: String, timestamp: Long, updatedAt: Long) + + @Query("SELECT COUNT(*) FROM mcp_servers") + fun getServerCount(): Flow + + @Query("SELECT COUNT(*) FROM mcp_servers WHERE isEnabled = 1") + fun getEnabledServerCount(): Flow + + @Query("SELECT * FROM mcp_servers ORDER BY name ASC") + suspend fun getAllServersSnapshot(): List +} diff --git a/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt b/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt index 4c3d16d6..99c445d4 100644 --- a/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt +++ b/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt @@ -6,7 +6,9 @@ import com.dark.tool_neuron.database.AppDatabase import com.dark.tool_neuron.database.dao.AiMemoryDao import com.dark.tool_neuron.database.dao.PersonaDao import com.dark.tool_neuron.repo.ChatRepository +import com.dark.tool_neuron.repo.McpServerRepository import com.dark.tool_neuron.repo.ModelRepository +import com.dark.tool_neuron.service.McpClientService import com.dark.tool_neuron.vault.VaultHelper import com.dark.tool_neuron.viewmodel.factory.ChatListViewModelFactory import com.dark.tool_neuron.viewmodel.factory.ChatViewModelFactory @@ -23,6 +25,8 @@ object AppContainer { private lateinit var database: AppDatabase private lateinit var modelRepository: ModelRepository private lateinit var chatRepository: ChatRepository + private lateinit var mcpServerRepository: McpServerRepository + private lateinit var mcpClientService: McpClientService private lateinit var llmModelViewModelFactory: LLMModelViewModelFactory private lateinit var chatListViewModelFactory: ChatListViewModelFactory private lateinit var chatViewModelFactory: ChatViewModelFactory @@ -43,6 +47,8 @@ object AppContainer { ) chatRepository = ChatRepository() + mcpServerRepository = McpServerRepository(database.mcpServerDao()) + mcpClientService = McpClientService() llmModelViewModelFactory = LLMModelViewModelFactory(application, modelRepository) chatListViewModelFactory = ChatListViewModelFactory(chatManager) @@ -94,6 +100,10 @@ object AppContainer { fun getChatRepository(): ChatRepository = chatRepository + fun getMcpServerRepository(): McpServerRepository = mcpServerRepository + + fun getMcpClientService(): McpClientService = mcpClientService + fun getLLMModelViewModelFactory(): LLMModelViewModelFactory = llmModelViewModelFactory fun getChatListViewModelFactory(): ChatListViewModelFactory = chatListViewModelFactory @@ -113,4 +123,4 @@ object AppContainer { fun getAiMemoryDao(): AiMemoryDao = database.aiMemoryDao() fun getGenerationManager(): GenerationManager = generationManager -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt b/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt index e8765672..bf835929 100644 --- a/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt +++ b/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt @@ -4,8 +4,11 @@ import com.dark.tool_neuron.database.AppDatabase import com.dark.tool_neuron.engine.EmbeddingEngine import com.dark.tool_neuron.repo.ChatRepository + import com.dark.tool_neuron.repo.McpServerRepository + import com.dark.tool_neuron.repo.McpStoreRepository import com.dark.tool_neuron.repo.ModelRepository import com.dark.tool_neuron.repo.RagRepository + import com.dark.tool_neuron.service.McpClientService import com.dark.tool_neuron.worker.ChatManager import com.dark.tool_neuron.worker.GenerationManager import com.dark.tool_neuron.worker.RagVaultIntegration @@ -65,6 +68,26 @@ context = context ) } + + @Provides + @Singleton + fun provideMcpServerRepository(database: AppDatabase): McpServerRepository { + return McpServerRepository( + mcpServerDao = database.mcpServerDao() + ) + } + + @Provides + @Singleton + fun provideMcpStoreRepository( + @ApplicationContext context: Context, + mcpServerRepository: McpServerRepository + ): McpStoreRepository { + return McpStoreRepository( + context = context, + mcpServerRepository = mcpServerRepository + ) + } } @Module @@ -78,6 +101,17 @@ } } + @Module + @InstallIn(SingletonComponent::class) + object ServiceModule { + + @Provides + @Singleton + fun provideMcpClientService(): McpClientService { + return McpClientService() + } + } + @Module @InstallIn(SingletonComponent::class) object WorkerModule { diff --git a/app/src/main/java/com/dark/tool_neuron/models/McpStoreEntry.kt b/app/src/main/java/com/dark/tool_neuron/models/McpStoreEntry.kt new file mode 100644 index 00000000..37ab4d45 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/models/McpStoreEntry.kt @@ -0,0 +1,42 @@ +package com.dark.tool_neuron.models + +import kotlinx.serialization.Serializable + +/** + * Represents an MCP server entry in the remote registry (MCP Store). + * Users browse these entries and install them as local McpServer configurations. + */ +@Serializable +data class McpStoreEntry( + val id: String, + val name: String, + val description: String, + val url: String, + val transportType: String = "SSE", + val category: String = "general", + val requiresApiKey: Boolean = false, + val requiresTermux: Boolean = false, + val pipPackage: String? = null, + val setupCommand: String? = null, + val defaultPort: Int? = null, + val author: String = "", + val tags: List = emptyList(), + val iconName: String? = null, + val setupInstructions: String? = null +) + +/** + * Categories for MCP Store entries + */ +object McpStoreCategories { + const val ALL = "All" + const val SEARCH = "Search" + const val CODE = "Code" + const val DATA = "Data" + const val FILES = "Files" + const val AI = "AI" + const val UTILITIES = "Utilities" + const val LOCAL = "Local (Termux)" + + val all = listOf(ALL, SEARCH, CODE, DATA, FILES, AI, UTILITIES, LOCAL) +} diff --git a/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt b/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt index 1ed4e662..95958658 100644 --- a/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt +++ b/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt @@ -3,6 +3,7 @@ package com.dark.tool_neuron.models.converters import androidx.room.TypeConverter import com.dark.tool_neuron.models.enums.PathType import com.dark.tool_neuron.models.enums.ProviderType +import com.dark.tool_neuron.models.table_schema.McpTransportType import org.json.JSONArray class Converters { @@ -18,6 +19,12 @@ class Converters { @TypeConverter fun toPathType(value: String): PathType = PathType.valueOf(value) + @TypeConverter + fun fromMcpTransportType(value: McpTransportType): String = value.name + + @TypeConverter + fun toMcpTransportType(value: String): McpTransportType = McpTransportType.valueOf(value) + @TypeConverter fun fromStringList(value: List): String = JSONArray(value).toString() diff --git a/app/src/main/java/com/dark/tool_neuron/models/table_schema/McpServer.kt b/app/src/main/java/com/dark/tool_neuron/models/table_schema/McpServer.kt new file mode 100644 index 00000000..3fed8ce6 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/models/table_schema/McpServer.kt @@ -0,0 +1,75 @@ +package com.dark.tool_neuron.models.table_schema + +import androidx.room.Entity +import androidx.room.PrimaryKey + +/** + * Transport type for MCP server connections + */ +enum class McpTransportType { + SSE, // Server-Sent Events (HTTP) + STREAMABLE_HTTP // Streamable HTTP transport +} + +/** + * Connection status of an MCP server (runtime only, not persisted) + */ +enum class McpConnectionStatus { + DISCONNECTED, + CONNECTING, + CONNECTED, + ERROR +} + +/** + * Entity representing a remote MCP (Model Context Protocol) server configuration. + * MCP servers provide tools, resources, and prompts to LLM applications. + */ +@Entity(tableName = "mcp_servers") +data class McpServer( + @PrimaryKey + val id: String, + + /** Display name for the server */ + val name: String, + + /** Server URL (e.g., "https://api.example.com/mcp") */ + val url: String, + + /** Transport type for the connection */ + val transportType: McpTransportType = McpTransportType.SSE, + + /** Optional API key for authentication */ + val apiKey: String? = null, + + /** Whether the server is enabled */ + val isEnabled: Boolean = true, + + /** Last error message if connection failed */ + val lastError: String? = null, + + /** Timestamp when the server was added */ + val createdAt: Long = System.currentTimeMillis(), + + /** Timestamp when the server was last modified */ + val updatedAt: Long = System.currentTimeMillis(), + + /** Timestamp when last successfully connected */ + val lastConnectedAt: Long? = null, + + /** Optional description */ + val description: String = "", + + /** Custom headers as JSON string (e.g., for additional auth) */ + val customHeadersJson: String? = null, + + /** Whether this server runs locally (e.g., via Termux) */ + val isLocal: Boolean = false, + + /** ID of the MCP Store entry this server was installed from */ + val sourceStoreId: String? = null +) { + companion object { + fun generateId(): String = java.util.UUID.randomUUID().toString() + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/repo/McpServerRepository.kt b/app/src/main/java/com/dark/tool_neuron/repo/McpServerRepository.kt new file mode 100644 index 00000000..c1eefb0d --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/repo/McpServerRepository.kt @@ -0,0 +1,165 @@ +package com.dark.tool_neuron.repo + +import android.util.Log +import com.dark.tool_neuron.database.dao.McpServerDao +import com.dark.tool_neuron.models.table_schema.McpConnectionStatus +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import java.net.URI +import javax.inject.Inject +import javax.inject.Singleton + +/** + * Repository for managing MCP (Model Context Protocol) server configurations + */ +@Singleton +class McpServerRepository @Inject constructor( + private val mcpServerDao: McpServerDao +) { + companion object { + private const val TAG = "McpServerRepository" + } + // Runtime connection status tracking (not persisted) + private val _connectionStatuses = MutableStateFlow>(emptyMap()) + val connectionStatuses: StateFlow> = _connectionStatuses.asStateFlow() + + /** + * Get all configured MCP servers + */ + fun getAllServers(): Flow> = mcpServerDao.getAllServers() + + /** + * Get only enabled MCP servers + */ + fun getEnabledServers(): Flow> = mcpServerDao.getEnabledServers() + + /** + * Get a specific server by ID + */ + suspend fun getServerById(id: String): McpServer? = mcpServerDao.getServerById(id) + + /** + * Add a new MCP server + * @throws IllegalArgumentException if the URL is not valid + */ + suspend fun addServer( + name: String, + url: String, + transportType: McpTransportType = McpTransportType.SSE, + apiKey: String? = null, + description: String = "" + ): McpServer { + val trimmedUrl = url.trim() + + // Validate URL format + val validatedUrl = try { + val uri = URI(trimmedUrl) + if (uri.scheme.isNullOrBlank() || uri.host.isNullOrBlank()) { + throw IllegalArgumentException("Invalid server URL: missing scheme or host") + } + if (uri.scheme != "http" && uri.scheme != "https") { + throw IllegalArgumentException("Invalid server URL scheme: ${uri.scheme}") + } + trimmedUrl + } catch (e: IllegalArgumentException) { + throw e + } catch (e: Exception) { + throw IllegalArgumentException("Invalid server URL format: '$trimmedUrl'", e) + } + + val server = McpServer( + id = McpServer.generateId(), + name = name, + url = validatedUrl, + transportType = transportType, + apiKey = apiKey?.trim()?.takeIf { it.isNotEmpty() }, + description = description.trim(), + isEnabled = true, + createdAt = System.currentTimeMillis(), + updatedAt = System.currentTimeMillis() + ) + mcpServerDao.insertServer(server) + return server + } + + /** + * Update an existing MCP server + */ + suspend fun updateServer(server: McpServer) { + mcpServerDao.updateServer(server.copy(updatedAt = System.currentTimeMillis())) + } + + /** + * Delete an MCP server + */ + suspend fun deleteServer(id: String) { + mcpServerDao.deleteServerById(id) + // Remove from runtime status tracking + _connectionStatuses.value = _connectionStatuses.value - id + } + + /** + * Toggle server enabled/disabled state + */ + suspend fun setServerEnabled(id: String, enabled: Boolean) { + mcpServerDao.updateServerEnabled(id, enabled, System.currentTimeMillis()) + if (!enabled) { + // When disabled, set status to disconnected + updateConnectionStatus(id, McpConnectionStatus.DISCONNECTED) + } + } + + /** + * Update the runtime connection status of a server + * @param serverId The ID of the server + * @param status The new connection status + * @param error Optional error message when status is ERROR + */ + fun updateConnectionStatus(serverId: String, status: McpConnectionStatus, error: String? = null) { + if (error != null && status == McpConnectionStatus.ERROR) { + Log.w(TAG, "MCP server $serverId connection error: $error") + } + _connectionStatuses.value = _connectionStatuses.value + (serverId to status) + } + + /** + * Update last connected timestamp + */ + suspend fun updateLastConnected(id: String) { + val now = System.currentTimeMillis() + mcpServerDao.updateLastConnected(id, now, now) + } + + /** + * Get the count of all servers + */ + fun getServerCount(): Flow = mcpServerDao.getServerCount() + + /** + * Get the count of enabled servers + */ + fun getEnabledServerCount(): Flow = mcpServerDao.getEnabledServerCount() + + /** + * Get all servers as a one-shot snapshot (non-Flow) + */ + suspend fun getAllServersSnapshot(): List = mcpServerDao.getAllServersSnapshot() + + /** + * Get the current connection status for a server + */ + fun getConnectionStatus(serverId: String): McpConnectionStatus { + return _connectionStatuses.value[serverId] ?: McpConnectionStatus.DISCONNECTED + } + + /** + * Add a pre-built McpServer directly (used by MCP Store). + */ + suspend fun addServerDirect(server: McpServer) { + mcpServerDao.insertServer(server) + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/repo/McpStoreRepository.kt b/app/src/main/java/com/dark/tool_neuron/repo/McpStoreRepository.kt new file mode 100644 index 00000000..442e3ec2 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/repo/McpStoreRepository.kt @@ -0,0 +1,139 @@ +package com.dark.tool_neuron.repo + +import android.content.Context +import android.util.Log +import com.dark.tool_neuron.models.McpStoreCategories +import com.dark.tool_neuron.models.McpStoreEntry +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.withContext +import kotlinx.serialization.json.Json +import okhttp3.OkHttpClient +import okhttp3.Request +import javax.inject.Inject +import javax.inject.Singleton + +/** + * Repository for fetching and managing MCP Store registry entries. + * Loads a bundled fallback from assets and can refresh from a remote URL. + */ +@Singleton +class McpStoreRepository @Inject constructor( + private val context: Context, + private val mcpServerRepository: McpServerRepository +) { + companion object { + private const val TAG = "McpStoreRepository" + private const val REGISTRY_ASSET = "mcp-registry.json" + private const val REMOTE_REGISTRY_URL = + "https://raw.githubusercontent.com/Siddhesh2377/ToolNeuron/re-write/app/src/main/assets/mcp-registry.json" + } + + private val json = Json { ignoreUnknownKeys = true } + private val client = OkHttpClient() + + private val _entries = MutableStateFlow>(emptyList()) + val entries: StateFlow> = _entries.asStateFlow() + + private val _isLoading = MutableStateFlow(false) + val isLoading: StateFlow = _isLoading.asStateFlow() + + private val _error = MutableStateFlow(null) + val error: StateFlow = _error.asStateFlow() + + /** + * Load registry entries. Tries remote first, falls back to bundled asset. + */ + suspend fun loadEntries() { + if (_entries.value.isNotEmpty() && !_isLoading.value) return + _isLoading.value = true + _error.value = null + try { + val remote = fetchRemoteRegistry() + if (remote != null && remote.isNotEmpty()) { + _entries.value = remote + Log.d(TAG, "Loaded ${remote.size} entries from remote registry") + } else { + val local = loadBundledRegistry() + _entries.value = local + Log.d(TAG, "Loaded ${local.size} entries from bundled registry") + } + } catch (e: Exception) { + Log.e(TAG, "Failed to load registry, falling back to bundled", e) + try { + _entries.value = loadBundledRegistry() + } catch (e2: Exception) { + _error.value = "Failed to load MCP store: ${e2.message}" + Log.e(TAG, "Failed to load bundled registry", e2) + } + } finally { + _isLoading.value = false + } + } + + /** + * Force refresh from remote registry. + */ + suspend fun refresh() { + _entries.value = emptyList() + loadEntries() + } + + /** + * Filter entries by category and search query. + */ + fun filterEntries( + entries: List, + category: String, + searchQuery: String + ): List { + return entries.filter { entry -> + val matchesCategory = category == McpStoreCategories.ALL || + entry.category.equals(category, ignoreCase = true) || + (category == McpStoreCategories.LOCAL && entry.requiresTermux) + val matchesSearch = searchQuery.isBlank() || + entry.name.contains(searchQuery, ignoreCase = true) || + entry.description.contains(searchQuery, ignoreCase = true) || + entry.tags.any { it.contains(searchQuery, ignoreCase = true) } + matchesCategory && matchesSearch + } + } + + /** + * Check if a store entry is already installed as an MCP server. + */ + suspend fun isInstalled(storeEntryId: String): Boolean { + // Check all servers for a matching sourceStoreId + val servers = mcpServerRepository.getAllServersSnapshot() + return servers.any { it.sourceStoreId == storeEntryId } + } + + fun clearError() { + _error.value = null + } + + private suspend fun fetchRemoteRegistry(): List? = withContext(Dispatchers.IO) { + try { + val request = Request.Builder().url(REMOTE_REGISTRY_URL).build() + val response = client.newCall(request).execute() + if (response.isSuccessful) { + val body = response.body?.string() ?: return@withContext null + json.decodeFromString>(body) + } else { + Log.w(TAG, "Remote registry returned ${response.code}") + null + } + } catch (e: Exception) { + Log.w(TAG, "Could not fetch remote registry: ${e.message}") + null + } + } + + private suspend fun loadBundledRegistry(): List = withContext(Dispatchers.IO) { + val inputStream = context.assets.open(REGISTRY_ASSET) + val jsonString = inputStream.bufferedReader().use { it.readText() } + json.decodeFromString(jsonString) + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/service/McpClientService.kt b/app/src/main/java/com/dark/tool_neuron/service/McpClientService.kt new file mode 100644 index 00000000..d5b07b1f --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/service/McpClientService.kt @@ -0,0 +1,377 @@ +package com.dark.tool_neuron.service + +import android.util.Log +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody.Companion.toRequestBody +import org.json.JSONObject +import java.util.concurrent.TimeUnit +import javax.inject.Inject +import javax.inject.Singleton + +/** + * MCP Client response data + */ +data class McpToolInfo( + val name: String, + val description: String?, + val inputSchema: String? +) + +data class McpTestResult( + val success: Boolean, + val message: String, + val tools: List = emptyList(), + val serverInfo: String? = null +) + +/** + * Client service for connecting to remote MCP (Model Context Protocol) servers. + * Supports both SSE (Server-Sent Events) and Streamable HTTP transport types. + * + * Transport Types: + * - SSE: Uses text/event-stream for responses (commonly used by servers like Zapier MCP) + * - Streamable HTTP: Uses standard JSON responses + */ +@Singleton +class McpClientService @Inject constructor() { + + companion object { + private const val TAG = "McpClientService" + private const val CONNECT_TIMEOUT_SECONDS = 15L + private const val READ_TIMEOUT_SECONDS = 30L + private const val MCP_PROTOCOL_VERSION = "2024-11-05" + private const val CLIENT_NAME = "ToolNeuron" + private const val CLIENT_VERSION = "1.0.0" + private val JSON_MEDIA_TYPE = "application/json".toMediaType() + // Accept headers for different transport types + private const val ACCEPT_HEADER_SSE = "application/json, text/event-stream" + private const val ACCEPT_HEADER_HTTP = "application/json" + } + + private val httpClient = OkHttpClient.Builder() + .connectTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .readTimeout(READ_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .build() + + /** + * Clean up idle resources without destroying the singleton OkHttpClient. + * Only evicts idle connections — the client remains usable for future requests. + */ + fun close() { + try { + httpClient.connectionPool.evictAll() + httpClient.cache?.close() + } catch (e: Exception) { + Log.w(TAG, "Error while closing OkHttpClient resources", e) + } + } + + /** + * Get the appropriate Accept header based on transport type + */ + private fun getAcceptHeader(transportType: McpTransportType): String { + return when (transportType) { + McpTransportType.SSE -> ACCEPT_HEADER_SSE + McpTransportType.STREAMABLE_HTTP -> ACCEPT_HEADER_HTTP + } + } + + /** + * Parse response body, handling SSE format automatically. + * Some MCP servers return SSE-formatted responses regardless of the declared transport type, + * so we detect and parse SSE format for both transport types. + */ + private fun parseResponse(responseBody: String, transportType: McpTransportType): String { + // Always try to parse SSE format first, as some servers return SSE regardless of transport type + // The parseSseResponse function will return the original body if it's not SSE format + return parseSseResponse(responseBody) + } + + /** + * Parse SSE (Server-Sent Events) response format. + * SSE responses come as "event: message\ndata: {...json...}\n\n" + * This handles single-event responses commonly used in MCP request/response patterns. + * + * Note: For streaming scenarios, this parser extracts the last complete event. + * In MCP's request/response pattern, this is typically the only event. + */ + private fun parseSseResponse(responseBody: String): String { + // Check if this is an SSE response + if (!responseBody.contains("data:")) { + // Not SSE format, return as-is + return responseBody + } + + // Split by double newlines to separate events + val events = responseBody.split("\n\n") + + // Find the last event with data (for request/response pattern) + for (event in events.reversed()) { + val lines = event.lines() + val dataLines = lines.filter { it.startsWith("data:") } + + if (dataLines.isNotEmpty()) { + // Extract JSON from "data: {...}" format + // Multiple data lines in same event should be joined with newlines per SSE spec + val joinedData = dataLines.joinToString("\n") { it.removePrefix("data:").trim() } + + // Validate that the joined data is a valid JSON-RPC response + return try { + val json = JSONObject(joinedData) + if (!json.has("jsonrpc") || (!json.has("result") && !json.has("error") && !json.has("id"))) { + Log.w(TAG, "SSE data missing JSON-RPC fields; returning raw response") + responseBody + } else { + joinedData + } + } catch (e: Exception) { + Log.w(TAG, "SSE data is not valid JSON; returning raw SSE response body", e) + responseBody + } + } + } + + // Fallback: return original response + return responseBody + } + + /** + * Test connection to an MCP server and retrieve server capabilities + */ + suspend fun testConnection(server: McpServer): McpTestResult = withContext(Dispatchers.IO) { + try { + Log.d(TAG, "Testing connection to MCP server: ${server.name} at ${server.url} (transport: ${server.transportType})") + + // Build the initialize request according to MCP protocol + val initializeRequest = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", 1) + put("method", "initialize") + put("params", JSONObject().apply { + put("protocolVersion", MCP_PROTOCOL_VERSION) + put("capabilities", JSONObject()) + put("clientInfo", JSONObject().apply { + put("name", CLIENT_NAME) + put("version", CLIENT_VERSION) + }) + }) + } + + val requestBuilder = Request.Builder() + .url(server.url) + .post(initializeRequest.toString().toRequestBody(JSON_MEDIA_TYPE)) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", getAcceptHeader(server.transportType)) + + // Add API key if provided + server.apiKey?.let { key -> + requestBuilder.addHeader("Authorization", "Bearer $key") + } + + val response = httpClient.newCall(requestBuilder.build()).execute() + + if (!response.isSuccessful) { + return@withContext McpTestResult( + success = false, + message = "Server returned error: ${response.code} ${response.message}" + ) + } + + val rawResponseBody = response.body?.string() + if (rawResponseBody.isNullOrBlank()) { + return@withContext McpTestResult( + success = false, + message = "Server returned empty response" + ) + } + + // Parse response based on transport type + val responseBody = parseResponse(rawResponseBody, server.transportType) + + // Parse JSON response with specific error handling + val jsonResponse = try { + JSONObject(responseBody) + } catch (e: org.json.JSONException) { + Log.e(TAG, "Failed to parse MCP response as JSON: ${e.message}") + return@withContext McpTestResult( + success = false, + message = "Server returned invalid JSON response. The server may not be a valid MCP server." + ) + } + + // Check for JSON-RPC error + if (jsonResponse.has("error")) { + val error = jsonResponse.getJSONObject("error") + return@withContext McpTestResult( + success = false, + message = "Server error: ${error.optString("message", "Unknown error")}" + ) + } + + // Parse the result + val result = jsonResponse.optJSONObject("result") + val serverInfo = result?.optJSONObject("serverInfo") + val serverName = serverInfo?.optString("name", "Unknown Server") ?: "Unknown Server" + val serverVersion = serverInfo?.optString("version", "") ?: "" + + // Now list available tools + val tools = listTools(server) + + McpTestResult( + success = true, + message = "Connected successfully", + tools = tools, + serverInfo = if (serverVersion.isNotEmpty()) "$serverName v$serverVersion" else serverName + ) + + } catch (e: Exception) { + Log.e(TAG, "Failed to connect to MCP server: ${e.message}", e) + McpTestResult( + success = false, + message = "Connection failed: ${e.message ?: "Unknown error"}" + ) + } + } + + /** + * List available tools from an MCP server + */ + suspend fun listTools(server: McpServer): List = withContext(Dispatchers.IO) { + try { + val listToolsRequest = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", 2) + put("method", "tools/list") + put("params", JSONObject()) + } + + val requestBuilder = Request.Builder() + .url(server.url) + .post(listToolsRequest.toString().toRequestBody(JSON_MEDIA_TYPE)) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", getAcceptHeader(server.transportType)) + + server.apiKey?.let { key -> + requestBuilder.addHeader("Authorization", "Bearer $key") + } + + val response = httpClient.newCall(requestBuilder.build()).execute() + + if (!response.isSuccessful) { + Log.w(TAG, "listTools failed for '${server.name}': HTTP ${response.code} ${response.message}") + return@withContext emptyList() + } + + val rawResponseBody = response.body?.string() ?: return@withContext emptyList() + // Parse response based on transport type + val responseBody = parseResponse(rawResponseBody, server.transportType) + val jsonResponse = JSONObject(responseBody) + + if (jsonResponse.has("error")) { + val error = jsonResponse.getJSONObject("error") + Log.w(TAG, "listTools JSON-RPC error for '${server.name}': ${error.optString("message", "Unknown")}") + return@withContext emptyList() + } + + val result = jsonResponse.optJSONObject("result") ?: return@withContext emptyList() + val toolsArray = result.optJSONArray("tools") ?: return@withContext emptyList() + + val tools = mutableListOf() + for (i in 0 until toolsArray.length()) { + val tool = toolsArray.getJSONObject(i) + tools.add(McpToolInfo( + name = tool.getString("name"), + description = tool.optString("description", null), + inputSchema = tool.optJSONObject("inputSchema")?.toString() + )) + } + + tools + + } catch (e: Exception) { + Log.e(TAG, "Failed to list tools for '${server.name}': ${e.message}", e) + emptyList() + } + } + + /** + * Call a tool on an MCP server + */ + suspend fun callTool( + server: McpServer, + toolName: String, + arguments: Map + ): Result = callToolInternal(server, toolName, JSONObject(arguments)) + + suspend fun callTool( + server: McpServer, + toolName: String, + argumentsJson: String + ): Result { + val parsedArguments = try { + if (argumentsJson.isBlank()) JSONObject() else JSONObject(argumentsJson) + } catch (e: Exception) { + return Result.failure(Exception("Invalid tool arguments JSON: ${e.message}")) + } + return callToolInternal(server, toolName, parsedArguments) + } + + private suspend fun callToolInternal( + server: McpServer, + toolName: String, + arguments: JSONObject + ): Result = withContext(Dispatchers.IO) { + try { + val callToolRequest = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", System.currentTimeMillis()) + put("method", "tools/call") + put("params", JSONObject().apply { + put("name", toolName) + put("arguments", arguments) + }) + } + + val requestBuilder = Request.Builder() + .url(server.url) + .post(callToolRequest.toString().toRequestBody(JSON_MEDIA_TYPE)) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", getAcceptHeader(server.transportType)) + + server.apiKey?.let { key -> + requestBuilder.addHeader("Authorization", "Bearer $key") + } + + val response = httpClient.newCall(requestBuilder.build()).execute() + + if (!response.isSuccessful) { + return@withContext Result.failure(Exception("Server returned: ${response.code}")) + } + + val rawResponseBody = response.body?.string() + ?: return@withContext Result.failure(Exception("Empty response")) + + // Parse response based on transport type + val responseBody = parseResponse(rawResponseBody, server.transportType) + val jsonResponse = JSONObject(responseBody) + + if (jsonResponse.has("error")) { + val error = jsonResponse.getJSONObject("error") + return@withContext Result.failure(Exception(error.optString("message", "Unknown error"))) + } + + val result = jsonResponse.optJSONObject("result") + return@withContext Result.success(result?.toString() ?: responseBody) + + } catch (e: Exception) { + Log.e(TAG, "Failed to call tool: ${e.message}", e) + return@withContext Result.failure(e) + } + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/service/McpToolMapper.kt b/app/src/main/java/com/dark/tool_neuron/service/McpToolMapper.kt new file mode 100644 index 00000000..9973b526 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/service/McpToolMapper.kt @@ -0,0 +1,69 @@ +package com.dark.tool_neuron.service + +import com.dark.tool_neuron.models.table_schema.McpServer +import org.json.JSONArray +import org.json.JSONObject + +data class McpToolReference( + val server: McpServer, + val toolName: String +) + +data class McpToolMapping( + val toolsJson: String, + val toolRegistry: Map +) + +object McpToolMapper { + fun sanitizeIdentifier(value: String): String { + return value.lowercase() + .replace(Regex("[^a-z0-9]+"), "_") + .replace(Regex("_+"), "_") + .trim('_') + } + + fun buildMapping(serverTools: Map>): McpToolMapping { + val toolsArray = JSONArray() + val registry = mutableMapOf() + + serverTools.forEach { (server, tools) -> + val serverPrefix = sanitizeIdentifier(server.name).ifBlank { "mcp" } + tools.forEach { tool -> + val toolSlug = sanitizeIdentifier(tool.name).ifBlank { "tool" } + val toolId = "${serverPrefix}_${toolSlug}" + toolsArray.put(buildToolDefinition(toolId, tool)) + registry[toolId] = McpToolReference(server, tool.name) + } + } + + return McpToolMapping( + toolsJson = toolsArray.toString(), + toolRegistry = registry + ) + } + + private fun buildToolDefinition(toolId: String, tool: McpToolInfo): JSONObject { + val function = JSONObject().apply { + put("name", toolId) + tool.description?.takeIf { it.isNotBlank() }?.let { put("description", it) } + put("parameters", buildParameters(tool.inputSchema)) + } + + return JSONObject().apply { + put("type", "function") + put("function", function) + } + } + + private fun buildParameters(inputSchema: String?): JSONObject { + val parsedSchema = inputSchema?.takeIf { it.isNotBlank() }?.let { + runCatching { JSONObject(it) }.getOrNull() + } + + return (parsedSchema ?: JSONObject()).apply { + if (!has("type")) { + put("type", "object") + } + } + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/service/TermuxBridge.kt b/app/src/main/java/com/dark/tool_neuron/service/TermuxBridge.kt new file mode 100644 index 00000000..afd5eba9 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/service/TermuxBridge.kt @@ -0,0 +1,144 @@ +package com.dark.tool_neuron.service + +import android.content.Context +import android.content.Intent +import android.content.pm.PackageManager +import android.util.Log + +/** + * Bridge for interacting with the Termux app to run local MCP servers. + * Uses Termux's RUN_COMMAND intent API to execute commands. + */ +object TermuxBridge { + private const val TAG = "TermuxBridge" + + const val TERMUX_PACKAGE = "com.termux" + private const val RUN_COMMAND_SERVICE = "com.termux.app.RunCommandService" + private const val RUN_COMMAND_ACTION = "com.termux.RUN_COMMAND" + const val RUN_COMMAND_PERMISSION = "com.termux.permission.RUN_COMMAND" + + private const val EXTRA_COMMAND_PATH = "com.termux.RUN_COMMAND_PATH" + private const val EXTRA_ARGUMENTS = "com.termux.RUN_COMMAND_ARGUMENTS" + private const val EXTRA_WORKDIR = "com.termux.RUN_COMMAND_WORKDIR" + private const val EXTRA_BACKGROUND = "com.termux.RUN_COMMAND_BACKGROUND" + private const val EXTRA_SESSION_ACTION = "com.termux.RUN_COMMAND_SESSION_ACTION" + + const val FDROID_URL = "https://f-droid.org/packages/com.termux/" + const val GITHUB_URL = "https://github.com/termux/termux-app/releases" + + /** + * Check if Termux is installed on the device. + */ + fun isTermuxInstalled(context: Context): Boolean { + return try { + context.packageManager.getPackageInfo(TERMUX_PACKAGE, 0) + true + } catch (e: PackageManager.NameNotFoundException) { + false + } + } + + /** + * Check if the app has the RUN_COMMAND permission for Termux. + */ + fun hasRunCommandPermission(context: Context): Boolean { + return context.checkSelfPermission(RUN_COMMAND_PERMISSION) == + PackageManager.PERMISSION_GRANTED + } + + /** + * Launch the Termux app. + */ + fun launchTermux(context: Context) { + val intent = context.packageManager.getLaunchIntentForPackage(TERMUX_PACKAGE) + if (intent != null) { + intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + context.startActivity(intent) + } else { + Log.w(TAG, "Could not launch Termux — no launch intent found") + } + } + + /** + * Run a command in Termux via the RUN_COMMAND intent. + * + * @param context Android context + * @param command The executable path (e.g., "/data/data/com.termux/files/usr/bin/python") + * @param arguments Command arguments + * @param background If true, runs in background without a terminal session + * @param workdir Working directory for the command + */ + fun runCommand( + context: Context, + command: String, + arguments: Array = emptyArray(), + background: Boolean = true, + workdir: String? = null + ) { + val intent = Intent(RUN_COMMAND_ACTION).apply { + setClassName(TERMUX_PACKAGE, RUN_COMMAND_SERVICE) + putExtra(EXTRA_COMMAND_PATH, command) + putExtra(EXTRA_ARGUMENTS, arguments) + putExtra(EXTRA_BACKGROUND, background) + if (workdir != null) { + putExtra(EXTRA_WORKDIR, workdir) + } + // 0 = open new session, 1 = attach to current, 2 = do nothing + putExtra(EXTRA_SESSION_ACTION, "0") + } + try { + context.startForegroundService(intent) + Log.d(TAG, "Sent RUN_COMMAND to Termux: $command ${arguments.joinToString(" ")}") + } catch (e: Exception) { + Log.e(TAG, "Failed to send RUN_COMMAND to Termux", e) + } + } + + /** + * Install a pip package in Termux. + */ + fun pipInstall(context: Context, packageName: String) { + runCommand( + context = context, + command = "/data/data/com.termux/files/usr/bin/bash", + arguments = arrayOf("-c", "pip install $packageName"), + background = false + ) + } + + /** + * Start a Python-based MCP server in Termux. + * + * @param context Android context + * @param pipPackage The pip package name of the MCP server + * @param port The port to run the server on + * @param extraArgs Additional arguments for the server command + */ + fun startMcpServer( + context: Context, + pipPackage: String, + port: Int, + extraArgs: Array = emptyArray() + ) { + val serverCmd = buildString { + append("python -m $pipPackage --port $port") + if (extraArgs.isNotEmpty()) { + append(" ") + append(extraArgs.joinToString(" ")) + } + } + runCommand( + context = context, + command = "/data/data/com.termux/files/usr/bin/bash", + arguments = arrayOf("-c", serverCmd), + background = true + ) + } + + /** + * Get the localhost URL for a Termux-hosted MCP server. + */ + fun getLocalServerUrl(port: Int, path: String = "/sse"): String { + return "http://127.0.0.1:$port$path" + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/McpServersScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/McpServersScreen.kt new file mode 100644 index 00000000..824770e8 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/McpServersScreen.kt @@ -0,0 +1,871 @@ +package com.dark.tool_neuron.ui.screen + +import androidx.compose.animation.* +import androidx.compose.animation.core.* +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.hilt.navigation.compose.hiltViewModel +import androidx.lifecycle.compose.collectAsStateWithLifecycle +import com.dark.tool_neuron.models.table_schema.McpConnectionStatus +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.service.McpTestResult +import com.dark.tool_neuron.ui.components.ActionButton +import com.dark.tool_neuron.ui.components.ActionTextButton +import com.dark.tool_neuron.ui.components.CuteSwitch +import com.dark.tool_neuron.ui.theme.rDp +import com.dark.tool_neuron.viewmodel.McpServerUiState +import com.dark.tool_neuron.viewmodel.McpServerViewModel +import java.text.SimpleDateFormat +import java.util.* + +// Success color for connected/successful states +private val SuccessGreen = Color(0xFF4CAF50) + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun McpServersScreen( + onBackClick: () -> Unit, + onStoreClick: () -> Unit = {}, + viewModel: McpServerViewModel = hiltViewModel() +) { + val servers by viewModel.servers.collectAsStateWithLifecycle() + val serverCount by viewModel.serverCount.collectAsStateWithLifecycle() + val enabledServerCount by viewModel.enabledServerCount.collectAsStateWithLifecycle() + val showAddDialog by viewModel.showAddDialog.collectAsStateWithLifecycle() + val showEditDialog by viewModel.showEditDialog.collectAsStateWithLifecycle() + val selectedServer by viewModel.selectedServer.collectAsStateWithLifecycle() + val testingServerId by viewModel.testingServerId.collectAsStateWithLifecycle() + val testResult by viewModel.testResult.collectAsStateWithLifecycle() + val isLoading by viewModel.isLoading.collectAsStateWithLifecycle() + val error by viewModel.error.collectAsStateWithLifecycle() + + Scaffold( + topBar = { + CenterAlignedTopAppBar( + title = { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Text( + "MCP Servers", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold + ) + Text( + "$enabledServerCount active / $serverCount total", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + }, + navigationIcon = { + ActionTextButton( + onClickListener = onBackClick, + icon = Icons.Default.ChevronLeft, + text = "Back", + modifier = Modifier.padding(start = rDp(6.dp)) + ) + }, + actions = { + ActionButton( + onClickListener = onStoreClick, + icon = Icons.Default.Store, + modifier = Modifier.padding(end = rDp(4.dp)) + ) + ActionButton( + onClickListener = { viewModel.showAddServerDialog() }, + icon = Icons.Default.Add, + modifier = Modifier.padding(end = rDp(6.dp)) + ) + } + ) + } + ) { padding -> + Box( + modifier = Modifier + .fillMaxSize() + .padding(padding) + ) { + if (servers.isEmpty()) { + EmptyServersState(onAddServer = { viewModel.showAddServerDialog() }) + } else { + ServersList( + servers = servers, + testingServerId = testingServerId, + onServerClick = { viewModel.showEditServerDialog(it.server) }, + onToggleEnabled = { server, enabled -> + viewModel.toggleServerEnabled(server.server.id, enabled) + }, + onTestConnection = { viewModel.testConnection(it.server) }, + onDeleteServer = { viewModel.deleteServer(it.server.id) } + ) + } + + // Loading overlay + AnimatedVisibility( + visible = isLoading, + enter = fadeIn(), + exit = fadeOut() + ) { + Box( + modifier = Modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surface.copy(alpha = 0.8f)), + contentAlignment = Alignment.Center + ) { + CircularProgressIndicator() + } + } + + // Error snackbar + error?.let { errorMessage -> + Snackbar( + modifier = Modifier + .align(Alignment.BottomCenter) + .padding(rDp(16.dp)), + action = { + TextButton(onClick = { viewModel.clearError() }) { + Text("Dismiss") + } + } + ) { + Text(errorMessage) + } + } + } + } + + // Add Server Dialog + if (showAddDialog) { + AddEditServerDialog( + server = null, + isTesting = testingServerId == "new", + testResult = testResult, + onDismiss = { viewModel.hideAddServerDialog() }, + onSave = { name, url, transportType, apiKey, description -> + viewModel.addServer(name, url, transportType, apiKey, description) + }, + onTestConnection = { name, url, transportType, apiKey -> + viewModel.testConnectionWithParams(name, url, transportType, apiKey) + }, + onClearTestResult = { viewModel.clearTestResult() } + ) + } + + // Edit Server Dialog + if (showEditDialog && selectedServer != null) { + AddEditServerDialog( + server = selectedServer, + isTesting = testingServerId == selectedServer?.id, + testResult = testResult, + onDismiss = { viewModel.hideEditServerDialog() }, + onSave = { name, url, transportType, apiKey, description -> + selectedServer?.let { server -> + viewModel.updateServer( + server.copy( + name = name, + url = url, + transportType = transportType, + apiKey = apiKey?.takeIf { it.isNotBlank() }, + description = description + ) + ) + } + }, + onTestConnection = { name, url, transportType, apiKey -> + viewModel.testConnectionWithParams(name, url, transportType, apiKey) + }, + onClearTestResult = { viewModel.clearTestResult() } + ) + } +} + +@Composable +private fun EmptyServersState(onAddServer: () -> Unit) { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(rDp(16.dp)) + ) { + Icon( + imageVector = Icons.Default.Cloud, + contentDescription = null, + modifier = Modifier.size(rDp(72.dp)), + tint = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f) + ) + Text( + "No MCP Servers", + style = MaterialTheme.typography.titleMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + "Connect to remote MCP servers to extend\nyour AI capabilities with external tools", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f), + modifier = Modifier.padding(horizontal = rDp(32.dp)) + ) + Spacer(modifier = Modifier.height(rDp(8.dp))) + ActionTextButton( + onClickListener = onAddServer, + icon = Icons.Default.Add, + text = "Add Server", + shape = RoundedCornerShape(rDp(12.dp)) + ) + } + } +} + +@Composable +private fun ServersList( + servers: List, + testingServerId: String?, + onServerClick: (McpServerUiState) -> Unit, + onToggleEnabled: (McpServerUiState, Boolean) -> Unit, + onTestConnection: (McpServerUiState) -> Unit, + onDeleteServer: (McpServerUiState) -> Unit +) { + LazyColumn( + modifier = Modifier.fillMaxSize(), + contentPadding = PaddingValues(rDp(16.dp)), + verticalArrangement = Arrangement.spacedBy(rDp(12.dp)) + ) { + // Info card + item { + InfoCard() + } + + items(servers, key = { it.server.id }) { serverState -> + ServerCard( + serverState = serverState, + isTesting = testingServerId == serverState.server.id, + onClick = { onServerClick(serverState) }, + onToggleEnabled = { enabled -> onToggleEnabled(serverState, enabled) }, + onTestConnection = { onTestConnection(serverState) }, + onDelete = { onDeleteServer(serverState) } + ) + } + } +} + +@Composable +private fun InfoCard() { + Surface( + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + color = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ) { + Row( + modifier = Modifier.padding(rDp(16.dp)), + horizontalArrangement = Arrangement.spacedBy(rDp(12.dp)), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Info, + contentDescription = null, + tint = MaterialTheme.colorScheme.primary, + modifier = Modifier.size(rDp(24.dp)) + ) + Column(modifier = Modifier.weight(1f)) { + Text( + text = "MCP (Model Context Protocol)", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold, + color = MaterialTheme.colorScheme.onSurface + ) + Text( + text = "Connect to remote MCP servers to access external tools, resources, and capabilities for your AI conversations.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } +} + +@Composable +private fun ServerCard( + serverState: McpServerUiState, + isTesting: Boolean, + onClick: () -> Unit, + onToggleEnabled: (Boolean) -> Unit, + onTestConnection: () -> Unit, + onDelete: () -> Unit +) { + val server = serverState.server + val status = serverState.connectionStatus + + Card( + modifier = Modifier + .fillMaxWidth() + .clickable { onClick() }, + shape = RoundedCornerShape(rDp(16.dp)), + colors = CardDefaults.cardColors( + containerColor = when (status) { + McpConnectionStatus.CONNECTED -> MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.2f) + McpConnectionStatus.ERROR -> MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.2f) + McpConnectionStatus.CONNECTING -> MaterialTheme.colorScheme.tertiaryContainer.copy(alpha = 0.2f) + else -> MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + } + ) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(rDp(16.dp)) + ) { + // Header row + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + // Status indicator + StatusIndicator(status = status, isTesting = isTesting) + + Spacer(modifier = Modifier.width(rDp(12.dp))) + + Column { + Text( + text = server.name, + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.SemiBold, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Text( + text = server.url, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + + CuteSwitch( + checked = server.isEnabled, + onCheckedChange = onToggleEnabled + ) + } + + // Transport type badge + Spacer(modifier = Modifier.height(rDp(12.dp))) + Row( + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)), + verticalAlignment = Alignment.CenterVertically + ) { + TransportBadge(transportType = server.transportType) + + if (server.apiKey != null) { + Badge( + containerColor = MaterialTheme.colorScheme.secondary.copy(alpha = 0.1f), + contentColor = MaterialTheme.colorScheme.secondary + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(rDp(4.dp)), + modifier = Modifier.padding(horizontal = rDp(4.dp)) + ) { + Icon( + Icons.Default.Key, + contentDescription = null, + modifier = Modifier.size(rDp(12.dp)) + ) + Text("Auth", style = MaterialTheme.typography.labelSmall) + } + } + } + + server.lastConnectedAt?.let { lastConnected -> + Text( + text = "Last connected: ${formatTimestamp(lastConnected)}", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + + // Description + if (server.description.isNotBlank()) { + Spacer(modifier = Modifier.height(rDp(8.dp))) + Text( + text = server.description, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 2, + overflow = TextOverflow.Ellipsis + ) + } + + // Actions + Spacer(modifier = Modifier.height(rDp(12.dp))) + HorizontalDivider(color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.1f)) + Spacer(modifier = Modifier.height(rDp(12.dp))) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + ActionTextButton( + onClickListener = onTestConnection, + icon = Icons.Default.Refresh, + text = if (isTesting) "Testing..." else "Test Connection", + shape = RoundedCornerShape(rDp(12.dp)) + ) + + IconButton( + onClick = onDelete, + modifier = Modifier.size(rDp(36.dp)) + ) { + Icon( + Icons.Default.Delete, + contentDescription = "Delete", + tint = MaterialTheme.colorScheme.error, + modifier = Modifier.size(rDp(20.dp)) + ) + } + } + } + } +} + +@Composable +private fun StatusIndicator(status: McpConnectionStatus, isTesting: Boolean) { + val color = when { + isTesting -> MaterialTheme.colorScheme.tertiary + status == McpConnectionStatus.CONNECTED -> SuccessGreen + status == McpConnectionStatus.ERROR -> MaterialTheme.colorScheme.error + status == McpConnectionStatus.CONNECTING -> MaterialTheme.colorScheme.tertiary + else -> MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f) + } + + val infiniteTransition = rememberInfiniteTransition(label = "pulse") + val alpha by infiniteTransition.animateFloat( + initialValue = 1f, + targetValue = 0.3f, + animationSpec = infiniteRepeatable( + animation = tween(1000), + repeatMode = RepeatMode.Reverse + ), + label = "pulseAlpha" + ) + + Box( + modifier = Modifier + .size(rDp(12.dp)) + .clip(CircleShape) + .background( + if (isTesting || status == McpConnectionStatus.CONNECTING) { + color.copy(alpha = alpha) + } else { + color + } + ) + ) +} + +@Composable +private fun TransportBadge(transportType: McpTransportType) { + Badge( + containerColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f), + contentColor = MaterialTheme.colorScheme.primary + ) { + Text( + text = when (transportType) { + McpTransportType.SSE -> "SSE" + McpTransportType.STREAMABLE_HTTP -> "HTTP" + }, + style = MaterialTheme.typography.labelSmall, + modifier = Modifier.padding(horizontal = rDp(4.dp)) + ) + } +} + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +private fun AddEditServerDialog( + server: McpServer?, + isTesting: Boolean, + testResult: McpTestResult?, + onDismiss: () -> Unit, + onSave: (name: String, url: String, transportType: McpTransportType, apiKey: String?, description: String) -> Unit, + onTestConnection: (name: String, url: String, transportType: McpTransportType, apiKey: String?) -> Unit, + onClearTestResult: () -> Unit +) { + var name by remember { mutableStateOf(server?.name ?: "") } + var url by remember { mutableStateOf(server?.url ?: "") } + var transportType by remember { mutableStateOf(server?.transportType ?: McpTransportType.SSE) } + var apiKey by remember { mutableStateOf(server?.apiKey ?: "") } + var description by remember { mutableStateOf(server?.description ?: "") } + var showApiKey by remember { mutableStateOf(false) } + + val isValid = name.isNotBlank() && url.isNotBlank() && + (url.startsWith("http://") || url.startsWith("https://")) + + ModalBottomSheet( + onDismissRequest = onDismiss, + sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true), + containerColor = MaterialTheme.colorScheme.surface, + dragHandle = { + Box( + Modifier + .padding(vertical = rDp(12.dp)) + .width(rDp(40.dp)) + .height(rDp(4.dp)) + .clip(RoundedCornerShape(rDp(2.dp))) + .background(MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f)) + ) + } + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = rDp(24.dp)) + .padding(bottom = rDp(32.dp)) + ) { + // Header + Text( + text = if (server == null) "Add MCP Server" else "Edit MCP Server", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + + Text( + text = "Configure a remote MCP server connection", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(rDp(24.dp))) + + // Name field + OutlinedTextField( + value = name, + onValueChange = { + name = it + onClearTestResult() + }, + label = { Text("Server Name") }, + placeholder = { Text("My MCP Server") }, + singleLine = true, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + leadingIcon = { + Icon(Icons.Default.Label, contentDescription = null) + } + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // URL field + val isInsecureUrl = url.startsWith("http://") && !url.startsWith("https://") + val showSecurityWarning = isInsecureUrl && apiKey.isNotBlank() + + OutlinedTextField( + value = url, + onValueChange = { + url = it + onClearTestResult() + }, + label = { Text("Server URL") }, + placeholder = { Text("https://api.example.com/mcp") }, + singleLine = true, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + leadingIcon = { + Icon(Icons.Default.Link, contentDescription = null) + }, + trailingIcon = if (showSecurityWarning) { + { + Icon( + Icons.Default.Warning, + contentDescription = "Security warning", + tint = MaterialTheme.colorScheme.error + ) + } + } else null, + isError = url.isNotBlank() && !url.startsWith("http://") && !url.startsWith("https://"), + supportingText = when { + url.isNotBlank() && !url.startsWith("http://") && !url.startsWith("https://") -> { + { Text("URL must start with http:// or https://") } + } + showSecurityWarning -> { + { + Text( + "Warning: Using HTTP with an API key is insecure. Use HTTPS for secure connections.", + color = MaterialTheme.colorScheme.error + ) + } + } + isInsecureUrl -> { + { + Text( + "Consider using HTTPS for secure connections", + color = MaterialTheme.colorScheme.tertiary + ) + } + } + else -> null + }, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Uri) + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // Transport type selector + Text( + text = "Transport Type", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium + ) + Spacer(modifier = Modifier.height(rDp(8.dp))) + Row( + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)) + ) { + FilterChip( + selected = transportType == McpTransportType.SSE, + onClick = { + transportType = McpTransportType.SSE + onClearTestResult() + }, + label = { Text("SSE (Server-Sent Events)") }, + leadingIcon = if (transportType == McpTransportType.SSE) { + { Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(rDp(16.dp))) } + } else null + ) + FilterChip( + selected = transportType == McpTransportType.STREAMABLE_HTTP, + onClick = { + transportType = McpTransportType.STREAMABLE_HTTP + onClearTestResult() + }, + label = { Text("Streamable HTTP") }, + leadingIcon = if (transportType == McpTransportType.STREAMABLE_HTTP) { + { Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(rDp(16.dp))) } + } else null + ) + } + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // API Key field + OutlinedTextField( + value = apiKey, + onValueChange = { + apiKey = it + onClearTestResult() + }, + label = { Text("API Key (Optional)") }, + placeholder = { Text("Bearer token or API key") }, + singleLine = true, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + leadingIcon = { + Icon(Icons.Default.Key, contentDescription = null) + }, + trailingIcon = { + IconButton(onClick = { showApiKey = !showApiKey }) { + Icon( + if (showApiKey) Icons.Default.VisibilityOff else Icons.Default.Visibility, + contentDescription = if (showApiKey) "Hide" else "Show" + ) + } + }, + visualTransformation = if (showApiKey) VisualTransformation.None else PasswordVisualTransformation() + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // Description field + OutlinedTextField( + value = description, + onValueChange = { description = it }, + label = { Text("Description (Optional)") }, + placeholder = { Text("What this server provides...") }, + minLines = 2, + maxLines = 3, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)) + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // Test result + AnimatedVisibility( + visible = testResult != null, + enter = fadeIn() + expandVertically(), + exit = fadeOut() + shrinkVertically() + ) { + testResult?.let { result -> + TestResultCard(result = result) + Spacer(modifier = Modifier.height(rDp(16.dp))) + } + } + + // Action buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(rDp(12.dp)) + ) { + OutlinedButton( + onClick = { + onTestConnection(name, url, transportType, apiKey.takeIf { it.isNotBlank() }) + }, + enabled = isValid && !isTesting, + modifier = Modifier.weight(1f), + shape = RoundedCornerShape(rDp(12.dp)) + ) { + if (isTesting) { + CircularProgressIndicator( + modifier = Modifier.size(rDp(16.dp)), + strokeWidth = rDp(2.dp) + ) + Spacer(modifier = Modifier.width(rDp(8.dp))) + } + Text(if (isTesting) "Testing..." else "Test Connection") + } + + Button( + onClick = { + onSave(name, url, transportType, apiKey.takeIf { it.isNotBlank() }, description) + }, + enabled = isValid, + modifier = Modifier.weight(1f), + shape = RoundedCornerShape(rDp(12.dp)) + ) { + Icon(Icons.Default.Save, contentDescription = null) + Spacer(modifier = Modifier.width(rDp(8.dp))) + Text(if (server == null) "Add Server" else "Save Changes") + } + } + } + } +} + +@Composable +private fun TestResultCard(result: McpTestResult) { + Surface( + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + color = if (result.success) { + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + } else { + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + } + ) { + Column( + modifier = Modifier.padding(rDp(16.dp)) + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)) + ) { + Icon( + imageVector = if (result.success) Icons.Default.CheckCircle else Icons.Default.Error, + contentDescription = null, + tint = if (result.success) { + SuccessGreen + } else { + MaterialTheme.colorScheme.error + }, + modifier = Modifier.size(rDp(20.dp)) + ) + Text( + text = if (result.success) "Connection Successful" else "Connection Failed", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold, + color = if (result.success) { + SuccessGreen + } else { + MaterialTheme.colorScheme.error + } + ) + } + + if (result.serverInfo != null) { + Spacer(modifier = Modifier.height(rDp(4.dp))) + Text( + text = "Server: ${result.serverInfo}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + if (!result.success) { + Spacer(modifier = Modifier.height(rDp(4.dp))) + Text( + text = result.message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error + ) + } + + if (result.tools.isNotEmpty()) { + Spacer(modifier = Modifier.height(rDp(8.dp))) + Text( + text = "Available Tools (${result.tools.size}):", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Medium + ) + Spacer(modifier = Modifier.height(rDp(4.dp))) + result.tools.take(5).forEach { tool -> + Text( + text = "• ${tool.name}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + if (result.tools.size > 5) { + Text( + text = "... and ${result.tools.size - 5} more", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + } + } +} + +private fun formatTimestamp(timestamp: Long): String { + val now = System.currentTimeMillis() + val diff = now - timestamp + + return when { + diff < 60_000 -> "just now" + diff < 3600_000 -> "${diff / 60_000}m ago" + diff < 86400_000 -> "${diff / 3600_000}h ago" + diff < 604800_000 -> "${diff / 86400_000}d ago" + else -> { + val sdf = SimpleDateFormat("MMM dd", Locale.getDefault()) + sdf.format(Date(timestamp)) + } + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/McpStoreScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/McpStoreScreen.kt new file mode 100644 index 00000000..5c1671e2 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/McpStoreScreen.kt @@ -0,0 +1,434 @@ +package com.dark.tool_neuron.ui.screen + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.LazyRow +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.vector.ImageVector +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.hilt.navigation.compose.hiltViewModel +import androidx.lifecycle.compose.collectAsStateWithLifecycle +import com.dark.tool_neuron.models.McpStoreCategories +import com.dark.tool_neuron.models.McpStoreEntry +import com.dark.tool_neuron.ui.components.ActionButton +import com.dark.tool_neuron.ui.components.ActionTextButton +import com.dark.tool_neuron.ui.theme.rDp +import com.dark.tool_neuron.viewmodel.McpStoreViewModel + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun McpStoreScreen( + onBackClick: () -> Unit, + viewModel: McpStoreViewModel = hiltViewModel() +) { + val entries by viewModel.filteredEntries.collectAsStateWithLifecycle() + val searchQuery by viewModel.searchQuery.collectAsStateWithLifecycle() + val selectedCategory by viewModel.selectedCategory.collectAsStateWithLifecycle() + val isLoading by viewModel.isLoading.collectAsStateWithLifecycle() + val error by viewModel.error.collectAsStateWithLifecycle() + val installedIds by viewModel.installedIds.collectAsStateWithLifecycle() + val installMessage by viewModel.installMessage.collectAsStateWithLifecycle() + val showTermuxDialog by viewModel.showTermuxDialog.collectAsStateWithLifecycle() + val pendingTermuxEntry by viewModel.pendingTermuxEntry.collectAsStateWithLifecycle() + + Scaffold( + topBar = { + CenterAlignedTopAppBar( + title = { + Text( + "MCP Store", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold + ) + }, + navigationIcon = { + ActionTextButton( + onClickListener = onBackClick, + icon = Icons.Default.ChevronLeft, + text = "Back", + modifier = Modifier.padding(start = rDp(6.dp)) + ) + }, + actions = { + ActionButton( + onClickListener = { viewModel.refresh() }, + icon = Icons.Default.Refresh, + modifier = Modifier.padding(end = rDp(6.dp)) + ) + } + ) + } + ) { padding -> + Column( + modifier = Modifier + .fillMaxSize() + .padding(padding) + ) { + // Search bar + OutlinedTextField( + value = searchQuery, + onValueChange = { viewModel.setSearchQuery(it) }, + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = rDp(16.dp), vertical = rDp(8.dp)), + placeholder = { Text("Search MCP servers...") }, + leadingIcon = { Icon(Icons.Default.Search, contentDescription = "Search") }, + trailingIcon = { + if (searchQuery.isNotEmpty()) { + IconButton(onClick = { viewModel.setSearchQuery("") }) { + Icon(Icons.Default.Close, contentDescription = "Clear") + } + } + }, + singleLine = true, + shape = RoundedCornerShape(rDp(12.dp)) + ) + + // Category chips + LazyRow( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = rDp(12.dp)), + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)) + ) { + items(McpStoreCategories.all) { category -> + FilterChip( + selected = selectedCategory == category, + onClick = { viewModel.setSelectedCategory(category) }, + label = { Text(category) }, + leadingIcon = if (selectedCategory == category) { + { Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(16.dp)) } + } else null + ) + } + } + + Spacer(modifier = Modifier.height(rDp(8.dp))) + + // Content + Box(modifier = Modifier.fillMaxSize()) { + if (entries.isEmpty() && !isLoading) { + EmptyStoreState() + } else { + LazyColumn( + modifier = Modifier.fillMaxSize(), + contentPadding = PaddingValues(horizontal = rDp(16.dp), vertical = rDp(8.dp)), + verticalArrangement = Arrangement.spacedBy(rDp(12.dp)) + ) { + items(entries, key = { it.id }) { entry -> + StoreEntryCard( + entry = entry, + isInstalled = entry.id in installedIds, + isTermuxAvailable = viewModel.isTermuxInstalled, + onInstall = { viewModel.installEntry(entry) } + ) + } + } + } + + // Loading overlay + AnimatedVisibility( + visible = isLoading, + enter = fadeIn(), + exit = fadeOut() + ) { + Box( + modifier = Modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surface.copy(alpha = 0.8f)), + contentAlignment = Alignment.Center + ) { + CircularProgressIndicator() + } + } + + // Error/install message snackbar + val message = error ?: installMessage + message?.let { msg -> + Snackbar( + modifier = Modifier + .align(Alignment.BottomCenter) + .padding(rDp(16.dp)), + action = { + TextButton(onClick = { + viewModel.clearError() + viewModel.clearInstallMessage() + }) { + Text("Dismiss") + } + } + ) { + Text(msg) + } + } + } + } + } + + // Termux setup dialog + if (showTermuxDialog) { + TermuxSetupDialog( + entry = pendingTermuxEntry, + onDismiss = { viewModel.dismissTermuxDialog() }, + onDownloadTermux = { viewModel.openTermuxDownload(it) }, + onProceed = { viewModel.proceedWithTermuxInstall() }, + isTermuxInstalled = viewModel.isTermuxInstalled + ) + } +} + +@Composable +private fun StoreEntryCard( + entry: McpStoreEntry, + isInstalled: Boolean, + isTermuxAvailable: Boolean, + onInstall: () -> Unit +) { + Card( + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)) + ) { + Column( + modifier = Modifier.padding(rDp(16.dp)) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.Top + ) { + Row( + modifier = Modifier.weight(1f), + horizontalArrangement = Arrangement.spacedBy(rDp(12.dp)), + verticalAlignment = Alignment.CenterVertically + ) { + // Icon + Box( + modifier = Modifier + .size(rDp(40.dp)) + .clip(RoundedCornerShape(rDp(8.dp))) + .background(MaterialTheme.colorScheme.primaryContainer), + contentAlignment = Alignment.Center + ) { + Icon( + imageVector = getIconForEntry(entry), + contentDescription = null, + tint = MaterialTheme.colorScheme.onPrimaryContainer, + modifier = Modifier.size(rDp(24.dp)) + ) + } + + Column { + Text( + text = entry.name, + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.SemiBold + ) + Text( + text = "by ${entry.author}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + // Install button + if (isInstalled) { + FilledTonalButton( + onClick = {}, + enabled = false + ) { + Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(16.dp)) + Spacer(modifier = Modifier.width(4.dp)) + Text("Added") + } + } else { + Button(onClick = onInstall) { + Icon(Icons.Default.Add, contentDescription = null, modifier = Modifier.size(16.dp)) + Spacer(modifier = Modifier.width(4.dp)) + Text("Install") + } + } + } + + Spacer(modifier = Modifier.height(rDp(8.dp))) + + Text( + text = entry.description, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 2, + overflow = TextOverflow.Ellipsis + ) + + Spacer(modifier = Modifier.height(rDp(8.dp))) + + // Badges row + Row( + horizontalArrangement = Arrangement.spacedBy(rDp(6.dp)) + ) { + CategoryBadge(text = entry.category) + + if (entry.requiresApiKey) { + CategoryBadge(text = "API Key", color = MaterialTheme.colorScheme.tertiaryContainer) + } + + if (entry.requiresTermux) { + CategoryBadge( + text = if (isTermuxAvailable) "Termux" else "Termux Required", + color = if (isTermuxAvailable) + MaterialTheme.colorScheme.secondaryContainer + else + MaterialTheme.colorScheme.errorContainer + ) + } + + Text( + text = entry.transportType, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + // Setup instructions + entry.setupInstructions?.let { instructions -> + Spacer(modifier = Modifier.height(rDp(6.dp))) + Text( + text = instructions, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f), + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + } +} + +@Composable +private fun CategoryBadge( + text: String, + color: androidx.compose.ui.graphics.Color = MaterialTheme.colorScheme.secondaryContainer +) { + Surface( + shape = RoundedCornerShape(rDp(4.dp)), + color = color + ) { + Text( + text = text, + modifier = Modifier.padding(horizontal = rDp(6.dp), vertical = rDp(2.dp)), + style = MaterialTheme.typography.labelSmall + ) + } +} + +@Composable +private fun EmptyStoreState() { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Icon( + imageVector = Icons.Default.Store, + contentDescription = null, + modifier = Modifier.size(rDp(64.dp)), + tint = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f) + ) + Spacer(modifier = Modifier.height(rDp(16.dp))) + Text( + text = "No servers found", + style = MaterialTheme.typography.titleMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "Try a different search or category", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } +} + +@Composable +private fun TermuxSetupDialog( + entry: McpStoreEntry?, + onDismiss: () -> Unit, + onDownloadTermux: (android.content.Context) -> Unit, + onProceed: () -> Unit, + isTermuxInstalled: Boolean +) { + val context = androidx.compose.ui.platform.LocalContext.current + + AlertDialog( + onDismissRequest = onDismiss, + icon = { Icon(Icons.Default.Terminal, contentDescription = null) }, + title = { Text("Termux Required") }, + text = { + Column { + if (!isTermuxInstalled) { + Text("This MCP server (${entry?.name ?: "unknown"}) runs locally on your device using Termux.") + Spacer(modifier = Modifier.height(8.dp)) + Text("Termux is a free terminal emulator that lets you run Python and other tools on Android.") + Spacer(modifier = Modifier.height(8.dp)) + Text( + "Please install Termux from GitHub releases or F-Droid (not Play Store).", + style = MaterialTheme.typography.bodySmall, + fontWeight = FontWeight.SemiBold + ) + } else { + Text("Termux is installed. The pip package '${entry?.pipPackage ?: ""}' will be installed in Termux.") + Spacer(modifier = Modifier.height(8.dp)) + Text( + "Make sure Python is installed in Termux (run: pkg install python)", + style = MaterialTheme.typography.bodySmall + ) + } + } + }, + confirmButton = { + if (isTermuxInstalled) { + TextButton(onClick = onProceed) { + Text("Install") + } + } else { + TextButton(onClick = { onDownloadTermux(context) }) { + Text("Download Termux") + } + } + }, + dismissButton = { + TextButton(onClick = onDismiss) { + Text("Cancel") + } + } + ) +} + +private fun getIconForEntry(entry: McpStoreEntry): ImageVector { + return when (entry.iconName) { + "Search" -> Icons.Default.Search + "Code" -> Icons.Default.Code + "Language" -> Icons.Default.Language + "Folder" -> Icons.Default.Folder + "Storage" -> Icons.Default.Storage + "Psychology" -> Icons.Default.Psychology + "Science" -> Icons.Default.Science + "Cloud" -> Icons.Default.Cloud + "OndemandVideo" -> Icons.Default.OndemandVideo + else -> Icons.Default.Extension + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt index c91a59dc..caac6c31 100644 --- a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt @@ -17,6 +17,7 @@ import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Close +import androidx.compose.material.icons.filled.Cloud import androidx.compose.material.icons.filled.Delete import androidx.compose.material.icons.filled.Psychology import androidx.compose.material3.CircularProgressIndicator @@ -65,6 +66,7 @@ import java.util.Locale fun HomeDrawerScreen( onChatSelected: (String) -> Unit, onVaultManagerClick: () -> Unit, + onMcpServersClick: () -> Unit, chatViewModel: com.dark.tool_neuron.viewmodel.ChatViewModel, viewModel: ChatListViewModel = hiltViewModel() ) { @@ -97,6 +99,11 @@ fun HomeDrawerScreen( }, actions = { Row{ + ActionButton( + onClickListener = onMcpServersClick, + icon = Icons.Filled.Cloud, + modifier = Modifier.padding(end = rDp(6.dp)) + ) ActionButton( onClickListener = onVaultManagerClick, icon = R.drawable.smart_temp_message, @@ -138,7 +145,6 @@ fun HomeDrawerScreen( onChatClick = onChatSelected, onDeleteChat = { chatId -> viewModel.deleteChat(chatId) - // If deleting the currently loaded chat, start a new conversation if (chatId == currentChatId) { chatViewModel.startNewConversation() } @@ -176,7 +182,7 @@ private fun ChatList( scope.launch { isManualRefreshing = true onRefresh() - delay(2000) // Small delay to show the indicator + delay(2000) isManualRefreshing = false } }, @@ -409,4 +415,4 @@ private fun formatTimestamp(timestamp: Long): String { sdf.format(Date(timestamp)) } } -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt index 4bb057cc..0c9e3271 100644 --- a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt @@ -112,6 +112,7 @@ fun HomeScreen( onStoreButtonClicked: () -> Unit, onSettingsClick: () -> Unit, onVaultManagerClick: () -> Unit, + onMcpServersClick: () -> Unit, onCharacterClick: () -> Unit = {}, chatViewModel: ChatViewModel, llmModelViewModel: LLMModelViewModel @@ -130,6 +131,12 @@ fun HomeScreen( onVaultManagerClick = { onVaultManagerClick() }, + onMcpServersClick = { + scope.launch { + drawerState.close() + } + onMcpServersClick() + }, onChatSelected = { chatViewModel.loadChat(it) scope.launch { diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt index 4bb55852..af53fc40 100644 --- a/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt @@ -25,6 +25,11 @@ import com.dark.tool_neuron.models.plugins.PluginExecutionMetrics import com.dark.tool_neuron.models.plugins.PluginResultData import com.dark.tool_neuron.plugins.PluginExecutionResult import com.dark.tool_neuron.plugins.PluginManager +import com.dark.tool_neuron.plugins.MultiTurnToolResult +import com.dark.tool_neuron.repo.McpServerRepository +import com.dark.tool_neuron.service.McpClientService +import com.dark.tool_neuron.service.McpToolMapper +import com.dark.tool_neuron.service.McpToolReference import com.dark.tool_neuron.state.AppStateManager import com.dark.tool_neuron.worker.ChatManager import com.dark.tool_neuron.worker.DiffusionConfig @@ -65,6 +70,11 @@ class ChatViewModel @Inject constructor( private val appSettings = AppSettingsDataStore(context) private val ttsDataStore = com.dark.tool_neuron.tts.TTSDataStore(context) + // MCP server integration + private val mcpServerRepository: McpServerRepository get() = AppContainer.getMcpServerRepository() + private val mcpClientService: McpClientService get() = AppContainer.getMcpClientService() + private var mcpToolRegistry: Map = emptyMap() + val streamingEnabled: StateFlow = appSettings.streamingEnabled .stateIn(viewModelScope, SharingStarted.Eagerly, true) @@ -372,7 +382,7 @@ class ChatViewModel @Inject constructor( val maxTokens = getCurrentModelMaxTokens() val isNewChat = isNewConversation - val hasTools = PluginManager.hasEnabledTools() + val hasTools = (PluginManager.hasEnabledTools() || mcpToolRegistry.isNotEmpty()) && PluginManager.isToolCallingModelLoaded.value val ragContext = _currentRagContext.value @@ -452,7 +462,7 @@ class ChatViewModel @Inject constructor( chatManager.deleteMessage(lastAssistantMsg.msgId) } val maxTokens = getCurrentModelMaxTokens() - val hasTools = PluginManager.hasEnabledTools() + val hasTools = (PluginManager.hasEnabledTools() || mcpToolRegistry.isNotEmpty()) && PluginManager.isToolCallingModelLoaded.value val ragContext = _currentRagContext.value @@ -494,6 +504,9 @@ class ChatViewModel @Inject constructor( ) { val fullPrompt = ragContext?.let { "$it\n\n$prompt" } ?: prompt + // Sync MCP tools before planning + syncMcpTools() + // Phase 1: Plan _agentPhase.value = AgentPhase.Planning AppStateManager.setGeneratingText() @@ -580,7 +593,8 @@ class ChatViewModel @Inject constructor( _maxToolChainRounds.value = maxRounds val toolSignatures = PluginManager.getToolSignaturesText() - val enabledNames = PluginManager.getEnabledToolNames().map { it.lowercase() } + val enabledNames = PluginManager.getEnabledToolNames().map { it.lowercase() } + + mcpToolRegistry.keys.map { it.lowercase() } val truncatedPlan = plan.take(200) for (round in 1..maxRounds) { @@ -675,8 +689,13 @@ class ChatViewModel @Inject constructor( _currentToolName.value = normalizedName AppStateManager.setExecutingPlugin("", normalizedName) - val toolCall = ToolCall(name = normalizedName, arguments = argsObj) - val result = PluginManager.executeToolForMultiTurn(toolCall) + // Execute via MCP if it's an MCP tool, otherwise use PluginManager + val result = if (isMcpTool(normalizedName)) { + executeMcpToolCall(normalizedName, argsObj.toString()) + } else { + val toolCall = ToolCall(name = normalizedName, arguments = argsObj) + PluginManager.executeToolForMultiTurn(toolCall) + } val isSuccess = !result.isError if (isSuccess) { @@ -1083,7 +1102,8 @@ class ChatViewModel @Inject constructor( // Fallback: parse text if no ToolCall events were received if (toolCalls.isEmpty() && text.isNotBlank()) { Log.d(TAG, "No ToolCall events, trying text parsing fallback") - val enabledNames = PluginManager.getEnabledToolNames().map { it.lowercase() } + val enabledNames = PluginManager.getEnabledToolNames().map { it.lowercase() } + + mcpToolRegistry.keys.map { it.lowercase() } parseToolCallsFromText(text)?.let { parsed -> // Filter against enabled tools to reject hallucinated names val valid = parsed.filter { (name, _) -> @@ -1938,6 +1958,94 @@ class ChatViewModel @Inject constructor( _showModelList.value = false } + // ==================== MCP Tool Integration ==================== + + /** + * Sync MCP tools from enabled servers. Builds a tool registry mapping + * sanitized tool names to their MCP server + original name. + */ + private suspend fun syncMcpTools() { + try { + val servers = mcpServerRepository.getAllServersSnapshot() + .filter { it.isEnabled } + if (servers.isEmpty()) { + mcpToolRegistry = emptyMap() + return + } + + val serverTools = mutableMapOf>() + for (server in servers) { + try { + val tools = mcpClientService.listTools(server) + if (tools.isNotEmpty()) { + serverTools[server] = tools + } + } catch (e: Exception) { + Log.w(TAG, "Failed to list tools from MCP server '${server.name}': ${e.message}") + } + } + + if (serverTools.isEmpty()) { + mcpToolRegistry = emptyMap() + return + } + + val mapping = McpToolMapper.buildMapping(serverTools) + mcpToolRegistry = mapping.toolRegistry + Log.d(TAG, "Synced ${mcpToolRegistry.size} MCP tools from ${serverTools.size} servers") + } catch (e: Exception) { + Log.e(TAG, "Failed to sync MCP tools", e) + mcpToolRegistry = emptyMap() + } + } + + /** + * Execute an MCP tool call via the appropriate MCP server. + * Returns a MultiTurnToolResult compatible with the agent loop. + */ + private suspend fun executeMcpToolCall( + toolName: String, + argsJson: String + ): MultiTurnToolResult { + val ref = mcpToolRegistry[toolName] + ?: return MultiTurnToolResult( + toolName = toolName, + resultJson = "MCP tool not found: $toolName", + isError = true, + pluginName = "MCP", + executionTimeMs = 0 + ) + + val startTime = System.currentTimeMillis() + return try { + val result = mcpClientService.callTool(ref.server, ref.toolName, argsJson) + MultiTurnToolResult( + toolName = toolName, + resultJson = result, + isError = false, + pluginName = "MCP:${ref.server.name}", + executionTimeMs = System.currentTimeMillis() - startTime, + rawData = result + ) + } catch (e: Exception) { + MultiTurnToolResult( + toolName = toolName, + resultJson = "MCP tool execution failed: ${e.message}", + isError = true, + pluginName = "MCP:${ref.server.name}", + executionTimeMs = System.currentTimeMillis() - startTime + ) + } + } + + /** + * Check if a tool name belongs to an MCP server. + */ + fun isMcpTool(toolName: String): Boolean { + return mcpToolRegistry.containsKey(toolName) || + mcpToolRegistry.containsKey(toolName.lowercase()) + } + companion object { private const val TAG = "ChatViewModel" } diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/McpServerViewModel.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/McpServerViewModel.kt new file mode 100644 index 00000000..64688ef3 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/McpServerViewModel.kt @@ -0,0 +1,288 @@ +package com.dark.tool_neuron.viewmodel + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.dark.tool_neuron.models.table_schema.McpConnectionStatus +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.repo.McpServerRepository +import com.dark.tool_neuron.service.McpClientService +import com.dark.tool_neuron.service.McpTestResult +import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch +import javax.inject.Inject + +/** + * UI state for a single MCP server with runtime status + */ +data class McpServerUiState( + val server: McpServer, + val connectionStatus: McpConnectionStatus = McpConnectionStatus.DISCONNECTED +) + +/** + * ViewModel for managing MCP (Model Context Protocol) servers + */ +@HiltViewModel +class McpServerViewModel @Inject constructor( + private val repository: McpServerRepository, + private val mcpClientService: McpClientService +) : ViewModel() { + + companion object { + private const val ERROR_DISPLAY_DURATION_MS = 5000L + } + + // All servers with their runtime status + val servers: StateFlow> = combine( + repository.getAllServers(), + repository.connectionStatuses + ) { servers, statuses -> + servers.map { server -> + McpServerUiState( + server = server, + connectionStatus = statuses[server.id] ?: McpConnectionStatus.DISCONNECTED + ) + } + }.stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), emptyList()) + + // Server count + val serverCount: StateFlow = repository.getServerCount() + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), 0) + + // Enabled server count + val enabledServerCount: StateFlow = repository.getEnabledServerCount() + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), 0) + + // Currently selected server for editing + private val _selectedServer = MutableStateFlow(null) + val selectedServer: StateFlow = _selectedServer.asStateFlow() + + // Dialog state + private val _showAddDialog = MutableStateFlow(false) + val showAddDialog: StateFlow = _showAddDialog.asStateFlow() + + private val _showEditDialog = MutableStateFlow(false) + val showEditDialog: StateFlow = _showEditDialog.asStateFlow() + + // Test result for the current dialog + private val _testingServerId = MutableStateFlow(null) + val testingServerId: StateFlow = _testingServerId.asStateFlow() + + private val _testResult = MutableStateFlow(null) + val testResult: StateFlow = _testResult.asStateFlow() + + // Loading state + private val _isLoading = MutableStateFlow(false) + val isLoading: StateFlow = _isLoading.asStateFlow() + + // Error state + private val _error = MutableStateFlow(null) + val error: StateFlow = _error.asStateFlow() + private var errorClearJob: kotlinx.coroutines.Job? = null + + /** + * Set an error message that auto-clears after a timeout. + * Cancels any previous auto-clear job to prevent race conditions. + */ + private fun setError(message: String) { + // Cancel any pending error clear job + errorClearJob?.cancel() + + _error.value = message + + // Start new clear job + errorClearJob = viewModelScope.launch { + delay(ERROR_DISPLAY_DURATION_MS) + _error.value = null + } + } + + /** + * Show the add server dialog + */ + fun showAddServerDialog() { + _selectedServer.value = null + _testResult.value = null + _showAddDialog.value = true + } + + /** + * Hide the add server dialog + */ + fun hideAddServerDialog() { + _showAddDialog.value = false + _testResult.value = null + } + + /** + * Show the edit server dialog + */ + fun showEditServerDialog(server: McpServer) { + _selectedServer.value = server + _testResult.value = null + _showEditDialog.value = true + } + + /** + * Hide the edit server dialog + */ + fun hideEditServerDialog() { + _showEditDialog.value = false + _selectedServer.value = null + _testResult.value = null + } + + /** + * Add a new MCP server + */ + fun addServer( + name: String, + url: String, + transportType: McpTransportType = McpTransportType.SSE, + apiKey: String? = null, + description: String = "" + ) { + viewModelScope.launch { + try { + _isLoading.value = true + repository.addServer(name, url, transportType, apiKey, description) + hideAddServerDialog() + } catch (e: Exception) { + setError("Failed to add server: ${e.message}") + } finally { + _isLoading.value = false + } + } + } + + /** + * Update an existing MCP server + */ + fun updateServer(server: McpServer) { + viewModelScope.launch { + try { + _isLoading.value = true + repository.updateServer(server) + hideEditServerDialog() + } catch (e: Exception) { + setError("Failed to update server: ${e.message}") + } finally { + _isLoading.value = false + } + } + } + + /** + * Delete an MCP server + */ + fun deleteServer(serverId: String) { + viewModelScope.launch { + try { + repository.deleteServer(serverId) + } catch (e: Exception) { + setError("Failed to delete server: ${e.message}") + } + } + } + + /** + * Toggle server enabled state + */ + fun toggleServerEnabled(serverId: String, enabled: Boolean) { + viewModelScope.launch { + try { + repository.setServerEnabled(serverId, enabled) + } catch (e: Exception) { + setError("Failed to update server: ${e.message}") + } + } + } + + /** + * Test connection to a server + */ + fun testConnection(server: McpServer) { + viewModelScope.launch { + try { + _testingServerId.value = server.id + _testResult.value = null + repository.updateConnectionStatus(server.id, McpConnectionStatus.CONNECTING) + + val result = mcpClientService.testConnection(server) + _testResult.value = result + + if (result.success) { + repository.updateConnectionStatus(server.id, McpConnectionStatus.CONNECTED) + repository.updateLastConnected(server.id) + } else { + repository.updateConnectionStatus(server.id, McpConnectionStatus.ERROR, result.message) + } + } catch (e: Exception) { + _testResult.value = McpTestResult( + success = false, + message = "Test failed: ${e.message}" + ) + repository.updateConnectionStatus(server.id, McpConnectionStatus.ERROR, e.message) + } finally { + _testingServerId.value = null + } + } + } + + /** + * Test connection with provided parameters (for add/edit dialog) + */ + fun testConnectionWithParams( + name: String, + url: String, + transportType: McpTransportType, + apiKey: String? + ) { + viewModelScope.launch { + try { + _testingServerId.value = "new" + _testResult.value = null + + val tempServer = McpServer( + id = "test", + name = name, + url = url, + transportType = transportType, + apiKey = apiKey?.takeIf { it.isNotBlank() } + ) + + val result = mcpClientService.testConnection(tempServer) + _testResult.value = result + } catch (e: Exception) { + _testResult.value = McpTestResult( + success = false, + message = "Test failed: ${e.message}" + ) + } finally { + _testingServerId.value = null + } + } + } + + /** + * Clear error message + */ + fun clearError() { + _error.value = null + } + + /** + * Clear test result + */ + fun clearTestResult() { + _testResult.value = null + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/McpStoreViewModel.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/McpStoreViewModel.kt new file mode 100644 index 00000000..4f55dc21 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/McpStoreViewModel.kt @@ -0,0 +1,177 @@ +package com.dark.tool_neuron.viewmodel + +import android.content.Context +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.dark.tool_neuron.models.McpStoreCategories +import com.dark.tool_neuron.models.McpStoreEntry +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.repo.McpServerRepository +import com.dark.tool_neuron.repo.McpStoreRepository +import com.dark.tool_neuron.service.TermuxBridge +import dagger.hilt.android.lifecycle.HiltViewModel +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch +import javax.inject.Inject + +@HiltViewModel +class McpStoreViewModel @Inject constructor( + @ApplicationContext private val appContext: Context, + private val storeRepository: McpStoreRepository, + private val mcpServerRepository: McpServerRepository +) : ViewModel() { + + private val _searchQuery = MutableStateFlow("") + val searchQuery: StateFlow = _searchQuery.asStateFlow() + + private val _selectedCategory = MutableStateFlow(McpStoreCategories.ALL) + val selectedCategory: StateFlow = _selectedCategory.asStateFlow() + + private val _installedIds = MutableStateFlow>(emptySet()) + + private val _installMessage = MutableStateFlow(null) + val installMessage: StateFlow = _installMessage.asStateFlow() + + private val _showTermuxDialog = MutableStateFlow(false) + val showTermuxDialog: StateFlow = _showTermuxDialog.asStateFlow() + + private val _pendingTermuxEntry = MutableStateFlow(null) + val pendingTermuxEntry: StateFlow = _pendingTermuxEntry.asStateFlow() + + val isLoading = storeRepository.isLoading + val error = storeRepository.error + + val filteredEntries: StateFlow> = combine( + storeRepository.entries, + _searchQuery, + _selectedCategory + ) { entries, query, category -> + storeRepository.filterEntries(entries, category, query) + }.stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), emptyList()) + + val installedIds: StateFlow> = _installedIds.asStateFlow() + + val isTermuxInstalled: Boolean + get() = TermuxBridge.isTermuxInstalled(appContext) + + init { + viewModelScope.launch { + storeRepository.loadEntries() + refreshInstalledIds() + } + } + + fun setSearchQuery(query: String) { + _searchQuery.value = query + } + + fun setSelectedCategory(category: String) { + _selectedCategory.value = category + } + + fun refresh() { + viewModelScope.launch { + storeRepository.refresh() + refreshInstalledIds() + } + } + + fun clearError() { + storeRepository.clearError() + } + + fun clearInstallMessage() { + _installMessage.value = null + } + + fun dismissTermuxDialog() { + _showTermuxDialog.value = false + _pendingTermuxEntry.value = null + } + + /** + * Install an MCP store entry as a local McpServer configuration. + */ + fun installEntry(entry: McpStoreEntry) { + if (entry.requiresTermux && !TermuxBridge.isTermuxInstalled(appContext)) { + _pendingTermuxEntry.value = entry + _showTermuxDialog.value = true + return + } + + viewModelScope.launch { + try { + val url = if (entry.requiresTermux && entry.defaultPort != null) { + TermuxBridge.getLocalServerUrl(entry.defaultPort) + } else { + entry.url + } + + val transportType = try { + McpTransportType.valueOf(entry.transportType) + } catch (e: Exception) { + McpTransportType.SSE + } + + val server = McpServer( + id = McpServer.generateId(), + name = entry.name, + url = url, + transportType = transportType, + isEnabled = !entry.requiresTermux, + description = entry.description, + isLocal = entry.requiresTermux, + sourceStoreId = entry.id + ) + mcpServerRepository.addServerDirect(server) + _installedIds.value = _installedIds.value + entry.id + + if (entry.requiresTermux && entry.pipPackage != null) { + TermuxBridge.pipInstall(appContext, entry.pipPackage) + _installMessage.value = "Installed ${entry.name}. Installing pip package in Termux..." + } else { + _installMessage.value = "${entry.name} added to your MCP servers" + } + } catch (e: Exception) { + _installMessage.value = "Failed to install ${entry.name}: ${e.message}" + } + } + } + + /** + * Proceed with Termux entry installation after user acknowledges the dialog. + */ + fun proceedWithTermuxInstall() { + val entry = _pendingTermuxEntry.value ?: return + _showTermuxDialog.value = false + _pendingTermuxEntry.value = null + if (TermuxBridge.isTermuxInstalled(appContext)) { + installEntry(entry) + } + } + + fun openTermuxDownload(context: Context) { + try { + val intent = android.content.Intent( + android.content.Intent.ACTION_VIEW, + android.net.Uri.parse(TermuxBridge.GITHUB_URL) + ) + intent.addFlags(android.content.Intent.FLAG_ACTIVITY_NEW_TASK) + context.startActivity(intent) + } catch (e: Exception) { + _installMessage.value = "Could not open browser. Visit ${TermuxBridge.GITHUB_URL}" + } + } + + private suspend fun refreshInstalledIds() { + val servers = mcpServerRepository.getAllServersSnapshot() + _installedIds.value = servers.mapNotNull { it.sourceStoreId }.toSet() + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt index d67874be..e8020606 100644 --- a/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt @@ -19,4 +19,4 @@ class ChatViewModelFactory( } throw IllegalArgumentException("Unknown ViewModel class") } -} \ No newline at end of file +} diff --git a/app/src/test/java/com/dark/tool_neuron/integration/McpServerTest.kt b/app/src/test/java/com/dark/tool_neuron/integration/McpServerTest.kt new file mode 100644 index 00000000..63bb7cef --- /dev/null +++ b/app/src/test/java/com/dark/tool_neuron/integration/McpServerTest.kt @@ -0,0 +1,235 @@ +package com.dark.tool_neuron.integration + +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.service.McpToolInfo +import com.dark.tool_neuron.service.McpToolMapper +import org.json.JSONArray +import org.json.JSONObject +import org.junit.Assert.* +import org.junit.Test + +/** + * Unit tests for MCP server-related functionality. + * These tests validate McpToolMapper functionality, JSON parsing, + * and configuration objects without connecting to real MCP servers. + */ +class McpServerTest { + + // Helper function to parse SSE response format. + // This is a simplified version for tests that extracts JSON from single-event SSE responses. + // The production code in McpClientService.parseSseResponse() handles multiple events and validates JSON. + private fun parseSseData(sseResponse: String): String { + val dataLine = sseResponse.lines().find { it.startsWith("data:") } + ?: return sseResponse + return dataLine.removePrefix("data:").trim() + } + + /** + * Test that McpServer can be created with the correct configuration + * for connecting to Zapier's MCP endpoint. + */ + @Test + fun createZapierMcpServerConfiguration() { + val zapierUrl = "https://mcp.zapier.com/api/v1/connect?token=example-token" + + val server = McpServer( + id = McpServer.generateId(), + name = "Zapier MCP", + url = zapierUrl, + transportType = McpTransportType.SSE, + apiKey = null, // Token is in URL + description = "Zapier MCP integration for Google Docs tools" + ) + + assertNotNull(server.id) + assertEquals("Zapier MCP", server.name) + assertEquals(zapierUrl, server.url) + assertEquals(McpTransportType.SSE, server.transportType) + assertTrue(server.isEnabled) + } + + /** + * Test parsing of MCP initialize response in SSE format using helper function. + */ + @Test + fun parseMcpInitializeResponse() { + val sseResponse = """event: message +data: {"result":{"protocolVersion":"2024-11-05","capabilities":{"tools":{"listChanged":true}},"serverInfo":{"name":"zapier","title":"Zapier MCP","version":"1.0.0"}},"jsonrpc":"2.0","id":1}""" + + // Use helper function to extract JSON from SSE format + val jsonStr = parseSseData(sseResponse) + val json = JSONObject(jsonStr) + + assertEquals("2.0", json.getString("jsonrpc")) + assertEquals(1, json.getInt("id")) + + val result = json.getJSONObject("result") + assertEquals("2024-11-05", result.getString("protocolVersion")) + + val serverInfo = result.getJSONObject("serverInfo") + assertEquals("zapier", serverInfo.getString("name")) + assertEquals("1.0.0", serverInfo.getString("version")) + } + + /** + * Test parsing of MCP tools/list response. + */ + @Test + fun parseMcpToolsListResponse() { + val sseResponse = """event: message +data: {"result":{"tools":[{"name":"google_docs_create_document_from_text","description":"Create a new document from text.","inputSchema":{"type":"object","properties":{"title":{"type":"string"}},"required":[]}}]},"jsonrpc":"2.0","id":2}""" + + // Use helper function to extract JSON from SSE format + val jsonStr = parseSseData(sseResponse) + val json = JSONObject(jsonStr) + + val result = json.getJSONObject("result") + val tools = result.getJSONArray("tools") + + assertEquals(1, tools.length()) + + val tool = tools.getJSONObject(0) + assertEquals("google_docs_create_document_from_text", tool.getString("name")) + assertEquals("Create a new document from text.", tool.getString("description")) + + val inputSchema = tool.getJSONObject("inputSchema") + assertEquals("object", inputSchema.getString("type")) + } + + /** + * Test that McpToolMapper correctly maps Zapier tools to the LLM format. + */ + @Test + fun mapZapierToolsToLlmFormat() { + val server = McpServer( + id = "zapier-1", + name = "Zapier MCP", + url = "https://mcp.zapier.com/api/v1/connect", + transportType = McpTransportType.SSE + ) + + val tools = listOf( + McpToolInfo( + name = "google_docs_create_document_from_text", + description = "Create a new document from text. Also supports limited HTML.", + inputSchema = """{"type":"object","properties":{"instructions":{"type":"string","description":"Instructions for running this tool"},"title":{"type":"string","description":"Document Name"},"file":{"type":"string","description":"Document Content"}},"required":["instructions"]}""" + ), + McpToolInfo( + name = "google_docs_find_a_document", + description = "Search for a specific document by name.", + inputSchema = """{"type":"object","properties":{"instructions":{"type":"string","description":"Instructions for running this tool"},"title":{"type":"string","description":"Document Name"}},"required":["instructions"]}""" + ) + ) + + val mapping = McpToolMapper.buildMapping(mapOf(server to tools)) + + // Check that tools JSON is valid + val toolsArray = JSONArray(mapping.toolsJson) + assertEquals(2, toolsArray.length()) + + // Check first tool structure + val firstTool = toolsArray.getJSONObject(0) + assertEquals("function", firstTool.getString("type")) + + val function = firstTool.getJSONObject("function") + // Verify exact tool name format: "zapier_mcp_google_docs_create_document_from_text" + assertEquals("zapier_mcp_google_docs_create_document_from_text", function.getString("name")) + assertTrue(function.has("description")) + + // Check tool registry size and contents + assertEquals(2, mapping.toolRegistry.size) + + // Verify exact tool name mapping in registry + val toolNames = mapping.toolRegistry.values.map { it.toolName }.toSet() + assertEquals( + setOf("google_docs_create_document_from_text", "google_docs_find_a_document"), + toolNames + ) + + // Verify all entries reference the same server + mapping.toolRegistry.values.forEach { entry -> + assertEquals(server, entry.server) + } + } + + /** + * Test that tool call request is properly formatted for MCP protocol. + */ + @Test + fun formatMcpToolCallRequest() { + val toolName = "google_docs_create_document_from_text" + val arguments = JSONObject().apply { + put("instructions", "Create a document titled 'Test' with content 'Hello World'") + put("output_hint", "just the document URL") + put("title", "Test Document") + put("file", "Hello World") + } + + // Use fixed ID for deterministic test behavior + val request = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", 123L) + put("method", "tools/call") + put("params", JSONObject().apply { + put("name", toolName) + put("arguments", arguments) + }) + } + + assertEquals("2.0", request.getString("jsonrpc")) + assertEquals(123L, request.getLong("id")) + assertEquals("tools/call", request.getString("method")) + + val params = request.getJSONObject("params") + assertEquals(toolName, params.getString("name")) + + val args = params.getJSONObject("arguments") + assertEquals("Test Document", args.getString("title")) + assertEquals("Hello World", args.getString("file")) + } + + /** + * Test that both transport types can be assigned to McpServer. + */ + @Test + fun verifyTransportTypeAssignment() { + // SSE transport type + val sseServer = McpServer( + id = "server-sse", + name = "SSE Server", + url = "https://mcp.example.com/sse", + transportType = McpTransportType.SSE + ) + assertEquals(McpTransportType.SSE, sseServer.transportType) + + // Streamable HTTP transport type + val httpServer = McpServer( + id = "server-http", + name = "HTTP Server", + url = "https://mcp.example.com/http", + transportType = McpTransportType.STREAMABLE_HTTP + ) + assertEquals(McpTransportType.STREAMABLE_HTTP, httpServer.transportType) + } + + /** + * Test that server ID generation produces unique UUIDs. + */ + @Test + fun generateUniqueServerIds() { + val ids = mutableSetOf() + // Generate 10 IDs to demonstrate uniqueness with reasonable confidence + repeat(10) { + ids.add(McpServer.generateId()) + } + + // All 10 IDs should be unique + assertEquals(10, ids.size) + + // Verify IDs are valid UUID format (lowercase hexadecimal) + ids.forEach { id -> + assertTrue("ID should be a valid UUID format", id.matches(Regex("[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"))) + } + } +} diff --git a/app/src/test/java/com/dark/tool_neuron/service/McpToolMapperTest.kt b/app/src/test/java/com/dark/tool_neuron/service/McpToolMapperTest.kt new file mode 100644 index 00000000..55e09aa7 --- /dev/null +++ b/app/src/test/java/com/dark/tool_neuron/service/McpToolMapperTest.kt @@ -0,0 +1,54 @@ +package com.dark.tool_neuron.service + +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import org.json.JSONArray +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotNull +import org.junit.Test + +class McpToolMapperTest { + @Test + fun buildMappingCreatesToolRegistry() { + val server = McpServer( + id = "server-1", + name = "Zapier MCP", + url = "https://example.com/mcp", + transportType = McpTransportType.SSE + ) + val tool = McpToolInfo( + name = "send-email", + description = "Send an email", + inputSchema = """{"type":"object","properties":{"to":{"type":"string"}}}""" + ) + + val mapping = McpToolMapper.buildMapping(mapOf(server to listOf(tool))) + val toolsArray = JSONArray(mapping.toolsJson) + + assertEquals(1, toolsArray.length()) + val function = toolsArray.getJSONObject(0).getJSONObject("function") + assertEquals("zapier_mcp_send_email", function.getString("name")) + assertEquals("object", function.getJSONObject("parameters").getString("type")) + + val reference = mapping.toolRegistry["zapier_mcp_send_email"] + assertNotNull(reference) + assertEquals(server, reference?.server) + assertEquals("send-email", reference?.toolName) + } + + @Test + fun sanitizeIdentifierCollapsesConsecutiveSpecialChars() { + assertEquals("my_tool", McpToolMapper.sanitizeIdentifier("My--Tool")) + assertEquals("a_b", McpToolMapper.sanitizeIdentifier("a---b")) + assertEquals("hello_world", McpToolMapper.sanitizeIdentifier(" hello world ")) + assertEquals("test", McpToolMapper.sanitizeIdentifier("---test---")) + assertEquals("a_b_c", McpToolMapper.sanitizeIdentifier("a..b..c")) + } + + @Test + fun sanitizeIdentifierHandlesEdgeCases() { + assertEquals("mcp", McpToolMapper.sanitizeIdentifier("").ifBlank { "mcp" }) + assertEquals("abc123", McpToolMapper.sanitizeIdentifier("ABC123")) + assertEquals("tool_name_v2", McpToolMapper.sanitizeIdentifier("tool-name-v2")) + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 2b83cc98..d341d9db 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -36,6 +36,7 @@ coil = "3.1.0" androidx-espresso-core = "3.7.0" androidx-junit = "1.3.0" junit = "4.13.2" +org-json = "20240303" [libraries] coil-compose = { module = "io.coil-kt.coil3:coil-compose", version.ref = "coil" } @@ -84,6 +85,7 @@ xz = { group = "org.tukaani", name = "xz", version.ref = "xz" } androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "androidx-espresso-core" } androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" } junit = { group = "junit", name = "junit", version.ref = "junit" } +org-json = { group = "org.json", name = "json", version.ref = "org-json" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } @@ -92,4 +94,4 @@ google-dagger-hilt = { id = "com.google.dagger.hilt.android", version.ref = "dag kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } kotlin-ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } -kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } \ No newline at end of file +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }