Compare commits
6 commits
main
...
local-hugg
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4bfc8954f3 | ||
![]() |
18ca69f07f | ||
![]() |
f90459656f | ||
![]() |
5830628568 | ||
![]() |
f62e693b8f | ||
![]() |
4abdec044f |
11 changed files with 4228 additions and 1365 deletions
4681
Cargo.lock
generated
4681
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -428,7 +428,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
|
|||
async-recursion = "1.0.0"
|
||||
async-tar = "0.5.0"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.29.1"
|
||||
async-tungstenite = "0.30.0"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.2", features = [
|
||||
|
@ -515,7 +515,7 @@ objc = "0.2"
|
|||
open = "5.0.0"
|
||||
ordered-float = "2.1.1"
|
||||
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
||||
parking_lot = "0.12.1"
|
||||
parking_lot = "0.12.4"
|
||||
partial-json-fixer = "0.5.3"
|
||||
parse_int = "0.9"
|
||||
pathdiff = "0.2"
|
||||
|
|
|
@ -19,6 +19,7 @@ doctest = false
|
|||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
|
|
|
@ -3,6 +3,7 @@ use std::path::PathBuf;
|
|||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
|
@ -13,6 +14,7 @@ use context_server::types::{
|
|||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
||||
pub struct ClaudeZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
|
@ -114,6 +116,7 @@ pub struct PermissionToolParams {
|
|||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[cfg_attr(test, derive(serde::Deserialize))]
|
||||
pub struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
|
@ -121,7 +124,8 @@ pub struct PermissionToolResponse {
|
|||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
#[cfg_attr(test, derive(serde::Deserialize))]
|
||||
pub enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
@ -141,6 +145,26 @@ impl McpServerTool for PermissionTool {
|
|||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
// Check if we should automatically allow tool actions
|
||||
let always_allow =
|
||||
cx.update(|cx| AgentSettings::get_global(cx).always_allow_tool_actions)?;
|
||||
|
||||
if always_allow {
|
||||
// If always_allow_tool_actions is true, immediately return Allow without prompting
|
||||
let response = PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: input.input,
|
||||
};
|
||||
|
||||
return Ok(ToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&response)?,
|
||||
}],
|
||||
structured_content: (),
|
||||
});
|
||||
}
|
||||
|
||||
// Otherwise, proceed with the normal permission flow
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
|
@ -300,3 +324,78 @@ impl McpServerTool for EditTool {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_permission_tool_respects_always_allow_setting(cx: &mut TestAppContext) {
|
||||
// Initialize settings
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
agent_settings::init(cx);
|
||||
});
|
||||
|
||||
// Create a test thread
|
||||
let project = cx.update(|cx| gpui::Entity::new(cx, |_cx| Project::local()));
|
||||
let thread = cx.update(|cx| {
|
||||
gpui::Entity::new(cx, |_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
acp::ConnectionId("test".into()),
|
||||
project,
|
||||
std::path::Path::new("/tmp"),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
let (tx, rx) = watch::channel(thread.downgrade());
|
||||
let tool = PermissionTool { thread_rx: rx };
|
||||
|
||||
// Test with always_allow_tool_actions = true
|
||||
cx.update(|cx| {
|
||||
AgentSettings::override_global(
|
||||
AgentSettings {
|
||||
always_allow_tool_actions: true,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
let input = PermissionToolParams {
|
||||
tool_name: "test_tool".to_string(),
|
||||
input: serde_json::json!({"test": "data"}),
|
||||
tool_use_id: Some("test_id".to_string()),
|
||||
};
|
||||
|
||||
let result = tool.run(input.clone(), &mut cx.to_async()).await.unwrap();
|
||||
|
||||
// Should return Allow without prompting
|
||||
assert_eq!(result.content.len(), 1);
|
||||
if let ToolResponseContent::Text { text } = &result.content[0] {
|
||||
let response: PermissionToolResponse = serde_json::from_str(text).unwrap();
|
||||
assert!(matches!(response.behavior, PermissionToolBehavior::Allow));
|
||||
} else {
|
||||
panic!("Expected text response");
|
||||
}
|
||||
|
||||
// Test with always_allow_tool_actions = false
|
||||
cx.update(|cx| {
|
||||
AgentSettings::override_global(
|
||||
AgentSettings {
|
||||
always_allow_tool_actions: false,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// This test would require mocking the permission prompt response
|
||||
// In the real scenario, it would wait for user input
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ use std::{
|
|||
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
|
||||
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{Entity, TestAppContext};
|
||||
|
@ -241,6 +242,57 @@ pub async fn test_tool_call_with_confirmation(
|
|||
});
|
||||
}
|
||||
|
||||
pub async fn test_tool_call_always_allow(
|
||||
server: impl AgentServer + 'static,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let fs = init_test(cx).await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
|
||||
// Enable always_allow_tool_actions
|
||||
cx.update(|cx| {
|
||||
AgentSettings::override_global(
|
||||
AgentSettings {
|
||||
always_allow_tool_actions: true,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Wait for the tool call to complete
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
// With always_allow_tool_actions enabled, the tool call should be immediately allowed
|
||||
// without waiting for confirmation
|
||||
let tool_call_entry = thread
|
||||
.entries()
|
||||
.iter()
|
||||
.find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
|
||||
.expect("Expected a tool call entry");
|
||||
|
||||
let AgentThreadEntry::ToolCall(tool_call) = tool_call_entry else {
|
||||
panic!("Expected tool call entry");
|
||||
};
|
||||
|
||||
// Should be allowed, not waiting for confirmation
|
||||
assert!(
|
||||
matches!(tool_call.status, ToolCallStatus::Allowed { .. }),
|
||||
"Expected tool call to be allowed automatically, but got {:?}",
|
||||
tool_call.status
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx).await;
|
||||
|
||||
|
@ -351,6 +403,12 @@ macro_rules! common_e2e_tests {
|
|||
async fn cancel(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_cancel($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn tool_call_always_allow(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_tool_call_always_allow($server, cx).await;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -38,6 +38,13 @@ impl AgentModelSelector {
|
|||
move |model, cx| {
|
||||
let provider = model.provider_id().0.to_string();
|
||||
let model_id = model.id().0.to_string();
|
||||
|
||||
// Authenticate the provider when a model is selected
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(provider) = registry.provider(&model.provider_id()) {
|
||||
provider.authenticate(cx).detach();
|
||||
}
|
||||
|
||||
match &model_usage_context {
|
||||
ModelUsageContext::Thread(thread) => {
|
||||
thread.update(cx, |thread, cx| {
|
||||
|
|
|
@ -15,6 +15,7 @@ path = "src/language_models.rs"
|
|||
ai_onboarding.workspace = true
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
anyhow.workspace = true
|
||||
|
||||
aws-config = { workspace = true, features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { workspace = true, features = [
|
||||
"hardcoded-credentials",
|
||||
|
@ -64,6 +65,7 @@ util.workspace = true
|
|||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
language.workspace = true
|
||||
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs", tag = "v0.6.0", features = [] }
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -17,6 +17,7 @@ use crate::provider::cloud::CloudLanguageModelProvider;
|
|||
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
|
||||
use crate::provider::google::GoogleLanguageModelProvider;
|
||||
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
|
||||
use crate::provider::local::LocalLanguageModelProvider;
|
||||
use crate::provider::mistral::MistralLanguageModelProvider;
|
||||
use crate::provider::ollama::OllamaLanguageModelProvider;
|
||||
use crate::provider::open_ai::OpenAiLanguageModelProvider;
|
||||
|
@ -150,4 +151,8 @@ fn register_language_model_providers(
|
|||
);
|
||||
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
|
||||
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
|
||||
registry.register_provider(
|
||||
LocalLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ pub mod copilot_chat;
|
|||
pub mod deepseek;
|
||||
pub mod google;
|
||||
pub mod lmstudio;
|
||||
pub mod local;
|
||||
pub mod mistral;
|
||||
pub mod ollama;
|
||||
pub mod open_ai;
|
||||
|
|
474
crates/language_models/src/provider/local.rs
Normal file
474
crates/language_models/src/provider/local.rs
Normal file
|
@ -0,0 +1,474 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
|
||||
};
|
||||
use mistralrs::{
|
||||
IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
|
||||
TextModelBuilder,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::{ButtonLike, IconName, Indicator, prelude::*};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
|
||||
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct LocalSettings {
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub display_name: Option<String>,
|
||||
pub max_tokens: u64,
|
||||
}
|
||||
|
||||
pub struct LocalLanguageModelProvider {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
model: Option<Arc<MistralModel>>,
|
||||
status: ModelStatus,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
enum ModelStatus {
|
||||
NotLoaded,
|
||||
Loading,
|
||||
Loaded,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(_cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
model: None,
|
||||
status: ModelStatus::NotLoaded,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
// Local models don't require authentication
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
// Skip if already loaded or currently loading
|
||||
if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
self.status = ModelStatus::Loading;
|
||||
cx.notify();
|
||||
|
||||
let background_executor = cx.background_executor().clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
eprintln!("Local model: Starting to load model");
|
||||
|
||||
// Move the model loading to a background thread
|
||||
let model_result = background_executor
|
||||
.spawn(async move { load_mistral_model().await })
|
||||
.await;
|
||||
|
||||
match model_result {
|
||||
Ok(model) => {
|
||||
eprintln!("Local model: Model loaded successfully");
|
||||
this.update(cx, |state, cx| {
|
||||
state.model = Some(model);
|
||||
state.status = ModelStatus::Loaded;
|
||||
cx.notify();
|
||||
eprintln!("Local model: Status updated to Loaded");
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = e.to_string();
|
||||
eprintln!("Local model: Failed to load model - {}", error_msg);
|
||||
this.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Error(error_msg.clone());
|
||||
cx.notify();
|
||||
eprintln!("Local model: Status updated to Failed");
|
||||
})?;
|
||||
Err(AuthenticateError::Other(anyhow!(
|
||||
"Failed to load model: {}",
|
||||
error_msg
|
||||
)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
|
||||
println!("\n\n\n\nLoading mistral model...\n\n\n");
|
||||
eprintln!("Starting to load model: {}", DEFAULT_MODEL);
|
||||
|
||||
// Configure the model builder to use background threads for downloads
|
||||
eprintln!("Creating TextModelBuilder...");
|
||||
let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
|
||||
|
||||
eprintln!("Building model (this should be quick for a 0.5B model)...");
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
match builder.build().await {
|
||||
Ok(model) => {
|
||||
let elapsed = start_time.elapsed();
|
||||
eprintln!("Model loaded successfully in {:?}", elapsed);
|
||||
Ok(Arc::new(model))
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load model: {:?}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalLanguageModelProvider {
|
||||
pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(State::new);
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for LocalLanguageModelProvider {
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for LocalLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Ai
|
||||
}
|
||||
|
||||
fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
vec![Arc::new(LocalLanguageModel {
|
||||
state: self.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})]
|
||||
}
|
||||
|
||||
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.provided_models(cx).into_iter().next()
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.default_model(cx)
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, _cx: &App) -> bool {
|
||||
// Local models don't require authentication
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|_cx| ConfigurationView {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| {
|
||||
state.model = None;
|
||||
state.status = ModelStatus::NotLoaded;
|
||||
cx.notify();
|
||||
});
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalLanguageModel {
|
||||
state: Entity<State>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl LocalLanguageModel {
|
||||
fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
|
||||
let mut messages = TextMessages::new();
|
||||
|
||||
for message in &request.messages {
|
||||
let mut text_content = String::new();
|
||||
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
text_content.push_str(text);
|
||||
}
|
||||
MessageContent::Image { .. } => {
|
||||
// For now, skip image content
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolResult { .. } => {
|
||||
// Skip tool results for now
|
||||
continue;
|
||||
}
|
||||
MessageContent::Thinking { .. } => {
|
||||
// Skip thinking content
|
||||
continue;
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {
|
||||
// Skip redacted thinking
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolUse(_) => {
|
||||
// Skip tool use
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if text_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let role = match message.role {
|
||||
Role::User => TextMessageRole::User,
|
||||
Role::Assistant => TextMessageRole::Assistant,
|
||||
Role::System => TextMessageRole::System,
|
||||
};
|
||||
|
||||
messages = messages.add_message(role, text_content);
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for LocalLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(DEFAULT_MODEL.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(DEFAULT_MODEL.into())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("local/{}", DEFAULT_MODEL)
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
128000 // Qwen2.5 supports 128k context
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
// Rough estimation: 1 token ≈ 4 characters
|
||||
let mut total_chars = 0;
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => total_chars += text.len(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
let tokens = (total_chars / 4) as u64;
|
||||
futures::future::ready(Ok(tokens)).boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let messages = self.to_mistral_messages(&request);
|
||||
let state = self.state.clone();
|
||||
let limiter = self.request_limiter.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let result: Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
> = limiter
|
||||
.run(async move {
|
||||
let model = cx
|
||||
.read_entity(&state, |state, _| {
|
||||
eprintln!(
|
||||
"Local model: Checking if model is loaded: {:?}",
|
||||
state.status
|
||||
);
|
||||
state.model.clone()
|
||||
})
|
||||
.map_err(|_| {
|
||||
LanguageModelCompletionError::Other(anyhow!("App state dropped"))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
eprintln!("Local model: Model is not loaded!");
|
||||
LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
|
||||
})?;
|
||||
|
||||
let (mut tx, rx) = mpsc::channel(32);
|
||||
|
||||
// Spawn a task to handle the stream
|
||||
let _ = smol::spawn(async move {
|
||||
let mut stream = match model.stream_chat_request(messages).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
let _ = tx
|
||||
.send(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"Failed to start stream: {}",
|
||||
e
|
||||
))))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let event = match response {
|
||||
MistralResponse::Chunk(chunk) => {
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
Some(Ok(LanguageModelCompletionEvent::Text(
|
||||
content.clone(),
|
||||
)))
|
||||
} else if let Some(finish_reason) = &choice.finish_reason {
|
||||
let stop_reason = match finish_reason.as_str() {
|
||||
"stop" => StopReason::EndTurn,
|
||||
"length" => StopReason::MaxTokens,
|
||||
_ => StopReason::EndTurn,
|
||||
};
|
||||
Some(Ok(LanguageModelCompletionEvent::Stop(
|
||||
stop_reason,
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MistralResponse::Done(_response) => {
|
||||
// For now, we don't emit usage events since the format doesn't match
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(event) = event {
|
||||
if tx.send(event).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx.boxed())
|
||||
})
|
||||
.await;
|
||||
|
||||
result
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let status = self.state.read(cx).status.clone();
|
||||
|
||||
div().size_full().child(
|
||||
div()
|
||||
.p_4()
|
||||
.child(
|
||||
div()
|
||||
.flex()
|
||||
.gap_2()
|
||||
.items_center()
|
||||
.child(match &status {
|
||||
ModelStatus::NotLoaded => Label::new("Model not loaded"),
|
||||
ModelStatus::Loading => Label::new("Loading model..."),
|
||||
ModelStatus::Loaded => Label::new("Model loaded"),
|
||||
ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
|
||||
})
|
||||
.child(match &status {
|
||||
ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
|
||||
ModelStatus::Loading => Indicator::dot().color(Color::Modified),
|
||||
ModelStatus::Loaded => Indicator::dot().color(Color::Success),
|
||||
ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
|
||||
}),
|
||||
)
|
||||
.when(!matches!(status, ModelStatus::Loading), |this| {
|
||||
this.child(
|
||||
ButtonLike::new("load_model")
|
||||
.child(Label::new(if matches!(status, ModelStatus::Loaded) {
|
||||
"Reload Model"
|
||||
} else {
|
||||
"Load Model"
|
||||
}))
|
||||
.on_click(cx.listener(|this, _, _window, cx| {
|
||||
this.state.update(cx, |state, cx| {
|
||||
state.authenticate(cx).detach();
|
||||
});
|
||||
})),
|
||||
)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
259
crates/language_models/src/provider/local/tests.rs
Normal file
259
crates/language_models/src/provider/local/tests.rs
Normal file
|
@ -0,0 +1,259 @@
|
|||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use http_client::FakeHttpClient;
|
||||
use language_model::{LanguageModelRequest, MessageContent, Role};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_local_provider_creation(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
cx.read(|cx| {
|
||||
assert_eq!(provider.id(), PROVIDER_ID);
|
||||
assert_eq!(provider.name(), PROVIDER_NAME);
|
||||
assert!(!provider.is_authenticated(cx));
|
||||
assert_eq!(provider.provided_models(cx).len(), 1);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_state_initialization(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let state = cx.new(State::new);
|
||||
|
||||
assert!(!state.read(cx).is_authenticated());
|
||||
assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
|
||||
assert!(state.read(cx).model.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_model_properties(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
// Create a model directly for testing (bypassing authentication)
|
||||
let model = LocalLanguageModel {
|
||||
state: provider.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
};
|
||||
|
||||
assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.provider_id(), PROVIDER_ID);
|
||||
assert_eq!(model.provider_name(), PROVIDER_NAME);
|
||||
assert_eq!(model.max_token_count(), 128000);
|
||||
assert!(!model.supports_tools());
|
||||
assert!(!model.supports_images());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_token_counting(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
let model = LocalLanguageModel {
|
||||
state: provider.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
};
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
messages: vec![language_model::LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello, world!".to_string())],
|
||||
cache: false,
|
||||
}],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: false,
|
||||
};
|
||||
|
||||
let count = cx
|
||||
.update(|cx| model.count_tokens(request, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// "Hello, world!" is 13 characters, so ~3 tokens
|
||||
assert!(count > 0);
|
||||
assert!(count < 10);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_message_conversion(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
let model = LocalLanguageModel {
|
||||
state: provider.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
};
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
messages: vec![
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![MessageContent::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello!".to_string())],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::Text("Hi there!".to_string())],
|
||||
cache: false,
|
||||
},
|
||||
],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: false,
|
||||
};
|
||||
|
||||
let _messages = model.to_mistral_messages(&request);
|
||||
// We can't directly inspect TextMessages, but we can verify it doesn't panic
|
||||
assert!(true); // Placeholder assertion
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_reset_credentials(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
// Simulate loading a model by just setting the status
|
||||
cx.update(|cx| {
|
||||
provider.state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Loaded;
|
||||
// We don't actually set a model since we can't mock it safely
|
||||
cx.notify();
|
||||
});
|
||||
});
|
||||
|
||||
cx.read(|cx| {
|
||||
// Since is_authenticated checks for model presence, we need to check status directly
|
||||
assert_eq!(provider.state.read(cx).status, ModelStatus::Loaded);
|
||||
});
|
||||
|
||||
// Reset credentials
|
||||
let task = cx.update(|cx| provider.reset_credentials(cx));
|
||||
task.await.unwrap();
|
||||
|
||||
cx.read(|cx| {
|
||||
assert!(!provider.is_authenticated(cx));
|
||||
assert_eq!(provider.state.read(cx).status, ModelStatus::NotLoaded);
|
||||
assert!(provider.state.read(cx).model.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: Fix this test - need to handle window creation in tests
|
||||
// #[gpui::test]
|
||||
// async fn test_configuration_view_rendering(cx: &mut TestAppContext) {
|
||||
// let http_client = FakeHttpClient::with_200_response();
|
||||
// let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
// let view = cx.update(|cx| provider.configuration_view(cx.window(), cx));
|
||||
|
||||
// // Basic test to ensure the view can be created without panicking
|
||||
// assert!(view.entity_type() == std::any::TypeId::of::<ConfigurationView>());
|
||||
// }
|
||||
|
||||
#[gpui::test]
|
||||
fn test_status_transitions(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let state = cx.new(State::new);
|
||||
|
||||
// Initial state
|
||||
assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
|
||||
|
||||
// Transition to loading
|
||||
state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Loading;
|
||||
cx.notify();
|
||||
});
|
||||
assert_eq!(state.read(cx).status, ModelStatus::Loading);
|
||||
|
||||
// Transition to loaded
|
||||
state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Loaded;
|
||||
cx.notify();
|
||||
});
|
||||
assert_eq!(state.read(cx).status, ModelStatus::Loaded);
|
||||
|
||||
// Transition to error
|
||||
state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Error("Test error".to_string());
|
||||
cx.notify();
|
||||
});
|
||||
match &state.read(cx).status {
|
||||
ModelStatus::Error(msg) => assert_eq!(msg, "Test error"),
|
||||
_ => panic!("Expected error status"),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_provider_shows_models_without_authentication(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
cx.read(|cx| {
|
||||
// Provider should show models even when not authenticated
|
||||
let models = provider.provided_models(cx);
|
||||
assert_eq!(models.len(), 1);
|
||||
|
||||
let model = &models[0];
|
||||
assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.provider_id(), PROVIDER_ID);
|
||||
assert_eq!(model.provider_name(), PROVIDER_NAME);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_provider_has_icon(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
assert_eq!(provider.icon(), IconName::Ai);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_provider_appears_in_registry(cx: &mut TestAppContext) {
|
||||
use language_model::LanguageModelRegistry;
|
||||
|
||||
cx.update(|cx| {
|
||||
let registry = cx.new(|_| LanguageModelRegistry::default());
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
|
||||
// Register the local provider
|
||||
registry.update(cx, |registry, cx| {
|
||||
let provider = LocalLanguageModelProvider::new(Arc::new(http_client), cx);
|
||||
registry.register_provider(provider, cx);
|
||||
});
|
||||
|
||||
// Verify the provider is registered
|
||||
let provider = registry.read(cx).provider(&PROVIDER_ID).unwrap();
|
||||
assert_eq!(provider.name(), PROVIDER_NAME);
|
||||
assert_eq!(provider.icon(), IconName::Ai);
|
||||
|
||||
// Verify it provides models even without authentication
|
||||
let models = provider.provided_models(cx);
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].id(), LanguageModelId(DEFAULT_MODEL.into()));
|
||||
});
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue