Get a smaller model working

This commit is contained in:
Richard Feldman 2025-07-29 21:51:06 -04:00
parent f90459656f
commit 18ca69f07f
No known key found for this signature in database
2 changed files with 73 additions and 34 deletions

View file

@ -38,6 +38,13 @@ impl AgentModelSelector {
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
// Authenticate the provider when a model is selected
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id()) {
provider.authenticate(cx).detach();
}
match &model_usage_context {
ModelUsageContext::Thread(thread) => {
thread.update(cx, |thread, cx| {

View file

@ -18,7 +18,7 @@ use ui::{ButtonLike, IconName, Indicator, prelude::*};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
@ -63,36 +63,47 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
if matches!(self.status, ModelStatus::Loading) {
// Skip if already loaded or currently loading
if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
return Task::ready(Ok(()));
}
self.status = ModelStatus::Loading;
cx.notify();
cx.spawn(async move |this, cx| match load_mistral_model().await {
Ok(model) => {
this.update(cx, |state, cx| {
state.model = Some(model);
state.status = ModelStatus::Loaded;
cx.notify();
})?;
Ok(())
}
Err(e) => {
let error_msg = e.to_string();
this.update(cx, |state, cx| {
state.status = ModelStatus::Error(error_msg.clone());
cx.notify();
})?;
Err(AuthenticateError::Other(anyhow!(
"Failed to load model: {}",
error_msg
)))
let background_executor = cx.background_executor().clone();
cx.spawn(async move |this, cx| {
eprintln!("Local model: Starting to load model");
// Move the model loading to a background thread
let model_result = background_executor
.spawn(async move { load_mistral_model().await })
.await;
match model_result {
Ok(model) => {
eprintln!("Local model: Model loaded successfully");
this.update(cx, |state, cx| {
state.model = Some(model);
state.status = ModelStatus::Loaded;
cx.notify();
eprintln!("Local model: Status updated to Loaded");
})?;
Ok(())
}
Err(e) => {
let error_msg = e.to_string();
eprintln!("Local model: Failed to load model - {}", error_msg);
this.update(cx, |state, cx| {
state.status = ModelStatus::Error(error_msg.clone());
cx.notify();
eprintln!("Local model: Status updated to Failed");
})?;
Err(AuthenticateError::Other(anyhow!(
"Failed to load model: {}",
error_msg
)))
}
}
})
}
@ -100,12 +111,26 @@ impl State {
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
println!("\n\n\n\nLoading mistral model...\n\n\n");
let model = TextModelBuilder::new(DEFAULT_MODEL)
.with_isq(IsqType::Q4_0)
.build()
.await?;
eprintln!("Starting to load model: {}", DEFAULT_MODEL);
Ok(Arc::new(model))
// Configure the model builder to use background threads for downloads
eprintln!("Creating TextModelBuilder...");
let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
eprintln!("Building model (this should be quick for a 0.5B model)...");
let start_time = std::time::Instant::now();
match builder.build().await {
Ok(model) => {
let elapsed = start_time.elapsed();
eprintln!("Model loaded successfully in {:?}", elapsed);
Ok(Arc::new(model))
}
Err(e) => {
eprintln!("Failed to load model: {:?}", e);
Err(e)
}
}
}
impl LocalLanguageModelProvider {
@ -256,7 +281,7 @@ impl LanguageModel for LocalLanguageModel {
}
fn supports_tools(&self) -> bool {
false
true
}
fn supports_images(&self) -> bool {
@ -264,11 +289,11 @@ impl LanguageModel for LocalLanguageModel {
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
true
}
fn max_token_count(&self) -> u64 {
128000 // GLM-4.5-Air supports 128k context
128000 // Qwen2.5 supports 128k context
}
fn count_tokens(
@ -315,11 +340,18 @@ impl LanguageModel for LocalLanguageModel {
> = limiter
.run(async move {
let model = cx
.read_entity(&state, |state, _| state.model.clone())
.read_entity(&state, |state, _| {
eprintln!(
"Local model: Checking if model is loaded: {:?}",
state.status
);
state.model.clone()
})
.map_err(|_| {
LanguageModelCompletionError::Other(anyhow!("App state dropped"))
})?
.ok_or_else(|| {
eprintln!("Local model: Model is not loaded!");
LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
})?;