Get a smaller model working
This commit is contained in:
parent
f90459656f
commit
18ca69f07f
2 changed files with 73 additions and 34 deletions
|
@ -38,6 +38,13 @@ impl AgentModelSelector {
|
|||
move |model, cx| {
|
||||
let provider = model.provider_id().0.to_string();
|
||||
let model_id = model.id().0.to_string();
|
||||
|
||||
// Authenticate the provider when a model is selected
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(provider) = registry.provider(&model.provider_id()) {
|
||||
provider.authenticate(cx).detach();
|
||||
}
|
||||
|
||||
match &model_usage_context {
|
||||
ModelUsageContext::Thread(thread) => {
|
||||
thread.update(cx, |thread, cx| {
|
||||
|
|
|
@ -18,7 +18,7 @@ use ui::{ButtonLike, IconName, Indicator, prelude::*};
|
|||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
|
||||
const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";
|
||||
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct LocalSettings {
|
||||
|
@ -63,36 +63,47 @@ impl State {
|
|||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
if matches!(self.status, ModelStatus::Loading) {
|
||||
// Skip if already loaded or currently loading
|
||||
if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
self.status = ModelStatus::Loading;
|
||||
cx.notify();
|
||||
|
||||
cx.spawn(async move |this, cx| match load_mistral_model().await {
|
||||
Ok(model) => {
|
||||
this.update(cx, |state, cx| {
|
||||
state.model = Some(model);
|
||||
state.status = ModelStatus::Loaded;
|
||||
cx.notify();
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = e.to_string();
|
||||
this.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Error(error_msg.clone());
|
||||
cx.notify();
|
||||
})?;
|
||||
Err(AuthenticateError::Other(anyhow!(
|
||||
"Failed to load model: {}",
|
||||
error_msg
|
||||
)))
|
||||
let background_executor = cx.background_executor().clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
eprintln!("Local model: Starting to load model");
|
||||
|
||||
// Move the model loading to a background thread
|
||||
let model_result = background_executor
|
||||
.spawn(async move { load_mistral_model().await })
|
||||
.await;
|
||||
|
||||
match model_result {
|
||||
Ok(model) => {
|
||||
eprintln!("Local model: Model loaded successfully");
|
||||
this.update(cx, |state, cx| {
|
||||
state.model = Some(model);
|
||||
state.status = ModelStatus::Loaded;
|
||||
cx.notify();
|
||||
eprintln!("Local model: Status updated to Loaded");
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = e.to_string();
|
||||
eprintln!("Local model: Failed to load model - {}", error_msg);
|
||||
this.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Error(error_msg.clone());
|
||||
cx.notify();
|
||||
eprintln!("Local model: Status updated to Failed");
|
||||
})?;
|
||||
Err(AuthenticateError::Other(anyhow!(
|
||||
"Failed to load model: {}",
|
||||
error_msg
|
||||
)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -100,12 +111,26 @@ impl State {
|
|||
|
||||
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
|
||||
println!("\n\n\n\nLoading mistral model...\n\n\n");
|
||||
let model = TextModelBuilder::new(DEFAULT_MODEL)
|
||||
.with_isq(IsqType::Q4_0)
|
||||
.build()
|
||||
.await?;
|
||||
eprintln!("Starting to load model: {}", DEFAULT_MODEL);
|
||||
|
||||
Ok(Arc::new(model))
|
||||
// Configure the model builder to use background threads for downloads
|
||||
eprintln!("Creating TextModelBuilder...");
|
||||
let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
|
||||
|
||||
eprintln!("Building model (this should be quick for a 0.5B model)...");
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
match builder.build().await {
|
||||
Ok(model) => {
|
||||
let elapsed = start_time.elapsed();
|
||||
eprintln!("Model loaded successfully in {:?}", elapsed);
|
||||
Ok(Arc::new(model))
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load model: {:?}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalLanguageModelProvider {
|
||||
|
@ -256,7 +281,7 @@ impl LanguageModel for LocalLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
|
@ -264,11 +289,11 @@ impl LanguageModel for LocalLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
128000 // GLM-4.5-Air supports 128k context
|
||||
128000 // Qwen2.5 supports 128k context
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
|
@ -315,11 +340,18 @@ impl LanguageModel for LocalLanguageModel {
|
|||
> = limiter
|
||||
.run(async move {
|
||||
let model = cx
|
||||
.read_entity(&state, |state, _| state.model.clone())
|
||||
.read_entity(&state, |state, _| {
|
||||
eprintln!(
|
||||
"Local model: Checking if model is loaded: {:?}",
|
||||
state.status
|
||||
);
|
||||
state.model.clone()
|
||||
})
|
||||
.map_err(|_| {
|
||||
LanguageModelCompletionError::Other(anyhow!("App state dropped"))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
eprintln!("Local model: Model is not loaded!");
|
||||
LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
|
||||
})?;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue