From 18ca69f07f1c45407fc82b02139cb6adf16b5786 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Tue, 29 Jul 2025 21:51:06 -0400 Subject: [PATCH] Get a smaller model working --- crates/agent_ui/src/agent_model_selector.rs | 7 ++ crates/language_models/src/provider/local.rs | 100 ++++++++++++------- 2 files changed, 73 insertions(+), 34 deletions(-) diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index b989e7bf1e..538e692d97 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -38,6 +38,13 @@ impl AgentModelSelector { move |model, cx| { let provider = model.provider_id().0.to_string(); let model_id = model.id().0.to_string(); + + // Authenticate the provider when a model is selected + let registry = LanguageModelRegistry::read_global(cx); + if let Some(provider) = registry.provider(&model.provider_id()) { + provider.authenticate(cx).detach(); + } + match &model_usage_context { ModelUsageContext::Thread(thread) => { thread.update(cx, |thread, cx| { diff --git a/crates/language_models/src/provider/local.rs b/crates/language_models/src/provider/local.rs index d0c6e117ee..7a7de5ad3c 100644 --- a/crates/language_models/src/provider/local.rs +++ b/crates/language_models/src/provider/local.rs @@ -18,7 +18,7 @@ use ui::{ButtonLike, IconName, Indicator, prelude::*}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local"); -const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit"; +const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct"; #[derive(Default, Debug, Clone, PartialEq)] pub struct LocalSettings { @@ -63,36 +63,47 @@ impl State { } fn authenticate(&mut self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - if matches!(self.status, ModelStatus::Loading) { + // Skip if already loaded or currently loading + if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) { return Task::ready(Ok(())); } self.status = ModelStatus::Loading; cx.notify(); - cx.spawn(async move |this, cx| match load_mistral_model().await { - Ok(model) => { - this.update(cx, |state, cx| { - state.model = Some(model); - state.status = ModelStatus::Loaded; - cx.notify(); - })?; - Ok(()) - } - Err(e) => { - let error_msg = e.to_string(); - this.update(cx, |state, cx| { - state.status = ModelStatus::Error(error_msg.clone()); - cx.notify(); - })?; - Err(AuthenticateError::Other(anyhow!( - "Failed to load model: {}", - error_msg - ))) + let background_executor = cx.background_executor().clone(); + cx.spawn(async move |this, cx| { + eprintln!("Local model: Starting to load model"); + + // Move the model loading to a background thread + let model_result = background_executor + .spawn(async move { load_mistral_model().await }) + .await; + + match model_result { + Ok(model) => { + eprintln!("Local model: Model loaded successfully"); + this.update(cx, |state, cx| { + state.model = Some(model); + state.status = ModelStatus::Loaded; + cx.notify(); + eprintln!("Local model: Status updated to Loaded"); + })?; + Ok(()) + } + Err(e) => { + let error_msg = e.to_string(); + eprintln!("Local model: Failed to load model - {}", error_msg); + this.update(cx, |state, cx| { + state.status = ModelStatus::Error(error_msg.clone()); + cx.notify(); + eprintln!("Local model: Status updated to Failed"); + })?; + Err(AuthenticateError::Other(anyhow!( + "Failed to load model: {}", + error_msg + ))) + } } }) } @@ -100,12 +111,26 @@ impl State { async fn load_mistral_model() -> Result> { println!("\n\n\n\nLoading mistral model...\n\n\n"); - let model = TextModelBuilder::new(DEFAULT_MODEL) - .with_isq(IsqType::Q4_0) - .build() - .await?; + eprintln!("Starting to load model: {}", DEFAULT_MODEL); - Ok(Arc::new(model)) + // Configure the model builder to use background threads for downloads + eprintln!("Creating TextModelBuilder..."); + let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K); + + eprintln!("Building model (this should be quick for a 0.5B model)..."); + let start_time = std::time::Instant::now(); + + match builder.build().await { + Ok(model) => { + let elapsed = start_time.elapsed(); + eprintln!("Model loaded successfully in {:?}", elapsed); + Ok(Arc::new(model)) + } + Err(e) => { + eprintln!("Failed to load model: {:?}", e); + Err(e) + } + } } impl LocalLanguageModelProvider { @@ -256,7 +281,7 @@ impl LanguageModel for LocalLanguageModel { } fn supports_tools(&self) -> bool { - false + true } fn supports_images(&self) -> bool { @@ -264,11 +289,11 @@ impl LanguageModel for LocalLanguageModel { } fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { - false + true } fn max_token_count(&self) -> u64 { - 128000 // GLM-4.5-Air supports 128k context + 128000 // Qwen2.5 supports 128k context } fn count_tokens( @@ -315,11 +340,18 @@ impl LanguageModel for LocalLanguageModel { > = limiter .run(async move { let model = cx - .read_entity(&state, |state, _| state.model.clone()) + .read_entity(&state, |state, _| { + eprintln!( + "Local model: Checking if model is loaded: {:?}", + state.status + ); + state.model.clone() + }) .map_err(|_| { LanguageModelCompletionError::Other(anyhow!("App state dropped")) })? .ok_or_else(|| { + eprintln!("Local model: Model is not loaded!"); LanguageModelCompletionError::Other(anyhow!("Model not loaded")) })?;