crates/language_models/src/provider/lmstudio.rs

use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
ModelType,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
use crate::AllLanguageModelSettings;
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";
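
/// Settings for the LM Studio provider, read from the user's Zed settings.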
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API, e.g. qwen2.5-coder-7b, phi-4, etc.
pub name: String,
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The model's context window size.
pub max_tokens: usize,
}
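
// A hypothetical `available_models` entry in Zed's settings that would
// deserialize into `AvailableModel` above (a sketch; the values are examples,
// not defaults):
//
//     "language_models": {
//       "lmstudio": {
//         "available_models": [
//           { "name": "qwen2.5-coder-7b", "display_name": "Qwen 2.5 Coder 7B", "max_tokens": 32768 }
//         ]
//       }
//     }

/// Registers LM Studio as a language-model provider, owning the HTTP client
/// and the observable provider [`State`].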
pub struct LmStudioLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
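
/// Provider state shared with the UI: the models discovered from the running
/// LM Studio server, the in-flight fetch task, and the settings subscription
/// that restarts the fetch whenever the LM Studio settings change.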
pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<lmstudio::Model>,
fetch_model_task: Option<Task<Result<()>>>,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
}
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
        // As a proxy for the server being "authenticated", we check whether it's up by fetching the models.
cx.spawn(|this, mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<lmstudio::Model> = models
.into_iter()
.filter(|model| model.r#type != ModelType::Embeddings)
.map(|model| lmstudio::Model::new(&model.id, None, None))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
let task = self.fetch_models(cx);
self.fetch_model_task.replace(task);
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let fetch_models_task = self.fetch_models(cx);
cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
}
}
impl LmStudioLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
move |this: &mut State, cx| {
let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
if &settings != new_settings {
settings = new_settings.clone();
this.restart_fetch_models_task(cx);
cx.notify();
}
}
});
State {
http_client,
available_models: Default::default(),
fetch_model_task: None,
_subscription: subscription,
}
}),
};
this.state
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
this
}
}
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for LmStudioLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn icon(&self) -> IconName {
IconName::AiLmStudio
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.provided_models(cx).into_iter().next()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
// Add models from the LM Studio API
for model in self.state.read(cx).available_models.iter() {
models.insert(model.name.clone(), model.clone());
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.lmstudio
.available_models
.iter()
{
models.insert(
model.name.clone(),
lmstudio::Model {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
},
);
}
models
.into_values()
.map(|model| {
Arc::new(LmStudioLanguageModel {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
.detach_and_log_err(cx);
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, cx)).into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
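
/// A single chat model served through LM Studio's OpenAI-compatible API.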
pub struct LmStudioLanguageModel {
id: LanguageModelId,
model: lmstudio::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl LmStudioLanguageModel {
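    /// Converts Zed's provider-agnostic request into LM Studio's
    /// chat-completion wire format, flattening each message to plain text.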
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
ChatCompletionRequest {
model: self.model.name.clone(),
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.string_contents(),
},
Role::Assistant => ChatMessage::Assistant {
content: Some(msg.string_contents()),
tool_calls: None,
},
Role::System => ChatMessage::System {
content: msg.string_contents(),
},
})
.collect(),
stream: true,
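            // NOTE: LM Studio's OpenAI-compatible endpoint treats -1 as
            // "no completion token limit".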
max_tokens: Some(-1),
stop: Some(request.stop),
temperature: request.temperature.or(Some(0.0)),
tools: vec![],
}
}
}
impl LanguageModel for LmStudioLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn telemetry_id(&self) -> String {
format!("lmstudio/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
_cx: &App,
) -> BoxFuture<'static, Result<usize>> {
        // There's no token-counting endpoint yet, so fall back to a rough word-count estimate.
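        // i.e. estimated_tokens = 0.75 * word_count; a 100-word prompt
        // estimates to 75 tokens.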
let token_count = request
.messages
.iter()
.map(|msg| msg.string_contents().split_whitespace().count())
.sum::<usize>();
let estimated_tokens = (token_count as f64 * 0.75) as usize;
async move { Ok(estimated_tokens) }.boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream = response
.filter_map(|response| async move {
match response {
                        Ok(fragment) => {
                            // Take the first choice if present rather than
                            // indexing, so an empty `choices` array can't panic.
                            let Some(choice) = fragment.choices.first() else {
                                return None;
                            };
                            // Skip empty deltas
                            if choice
                                .delta
                                .as_object()
                                .is_some_and(|delta| delta.is_empty())
                            {
                                return None;
                            }
                            // Try to parse the delta as ChatMessage
                            if let Ok(chat_message) =
                                serde_json::from_value::<ChatMessage>(choice.delta.clone())
                            {
let content = match chat_message {
ChatMessage::User { content } => content,
ChatMessage::Assistant { content, .. } => {
content.unwrap_or_default()
}
ChatMessage::System { content } => content,
};
if !content.is_empty() {
Some(Ok(content))
} else {
None
}
} else {
None
}
}
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
});
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_tool_name: String,
_tool_description: String,
_schema: serde_json::Value,
_cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
async move { Ok(futures::stream::empty().boxed()) }.boxed()
}
}
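
/// The provider's settings UI: attempts to connect to LM Studio when opened
/// and renders either onboarding instructions or a "Connected" indicator.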
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
}
impl ConfigurationView {
pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
let loading_models_task = Some(cx.spawn({
let state = state.clone();
|this, mut cx| async move {
if let Some(task) = state
.update(&mut cx, |state, cx| state.authenticate(cx))
.log_err()
{
task.await.log_err();
}
this.update(&mut cx, |this, cx| {
this.loading_models_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
state,
loading_models_task,
}
}
fn retry_connection(&self, cx: &mut App) {
self.state
.update(cx, |state, cx| state.fetch_models(cx))
.detach_and_log_err(cx);
}
}
impl Render for ConfigurationView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_authenticated = self.state.read(cx).is_authenticated();
let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs =
            "To use LM Studio as a provider for the Zed assistant, it needs to be running with at least one model downloaded.";
let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);
if self.loading_models_task.is_some() {
div().child(Label::new("Loading models...")).into_any()
} else {
v_flex()
.size_full()
.gap_3()
.child(
v_flex()
.size_full()
.gap_2()
.p_1()
.child(Label::new(lmstudio_intro))
.child(Label::new(lmstudio_reqs))
.child(
h_flex()
.gap_0p5()
.child(Label::new("To get your first model, try running"))
.child(
div()
.bg(inline_code_bg)
.px_1p5()
.rounded_md()
.child(Label::new("lms get qwen2.5-coder-7b")),
),
),
)
.child(
h_flex()
.w_full()
.pt_2()
.justify_between()
.gap_2()
.child(
h_flex()
.w_full()
.gap_2()
.map(|this| {
if is_authenticated {
this.child(
Button::new("lmstudio-site", "LM Studio")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| {
cx.open_url(LMSTUDIO_SITE)
})
.into_any_element(),
)
} else {
this.child(
Button::new(
"download_lmstudio_button",
"Download LM Studio",
)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| {
cx.open_url(LMSTUDIO_DOWNLOAD_URL)
})
.into_any_element(),
)
}
})
.child(
Button::new("view-models", "Model Catalog")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| {
cx.open_url(LMSTUDIO_CATALOG_URL)
}),
),
)
.child(if is_authenticated {
                        // This is rendered as a button only so the spacing matches
                        // the surrounding buttons; it should stay disabled.
ButtonLike::new("connected")
.disabled(true)
// Since this won't ever be clickable, we can use the arrow cursor
.cursor_style(gpui::CursorStyle::Arrow)
.child(
h_flex()
.gap_2()
.child(Indicator::dot().color(Color::Success))
.child(Label::new("Connected"))
.into_any_element(),
)
.into_any_element()
} else {
Button::new("retry_lmstudio_models", "Connect")
.icon_position(IconPosition::Start)
.icon(IconName::ArrowCircle)
.on_click(cx.listener(move |this, _, _window, cx| {
this.retry_connection(cx)
}))
.into_any_element()
}),
)
.into_any()
}
}
}