diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 9e8fd0c699..1386555582 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -386,7 +386,9 @@ impl AgentSettingsContent { _ => None, }; settings.provider = Some(AgentProviderContentV1::LmStudio { - default_model: Some(lmstudio::Model::new(&model, None, None, false)), + default_model: Some(lmstudio::Model::new( + &model, None, None, false, false, + )), api_url, }); } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 8f0704c5bc..8cb0829c2a 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -14,10 +14,7 @@ use language_model::{ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; -use lmstudio::{ - ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, - stream_chat_completion, -}; +use lmstudio::{ModelType, get_models}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -49,6 +46,7 @@ pub struct AvailableModel { pub display_name: Option, pub max_tokens: usize, pub supports_tool_calls: bool, + pub supports_images: bool, } pub struct LmStudioLanguageModelProvider { @@ -88,6 +86,7 @@ impl State { .loaded_context_length .or_else(|| model.max_context_length), model.capabilities.supports_tool_calls(), + model.capabilities.supports_images() || model.r#type == ModelType::Vlm, ) }) .collect(); @@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { display_name: model.display_name.clone(), max_tokens: model.max_tokens, supports_tool_calls: model.supports_tool_calls, + supports_images: model.supports_images, }, ); } @@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel { } impl LmStudioLanguageModel { - fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest { + fn to_lmstudio_request( + &self, + request: LanguageModelRequest, + ) -> lmstudio::ChatCompletionRequest { let mut messages = Vec::new(); for message in request.messages { for content in message.content { match content { - MessageContent::Text(text) => messages.push(match message.role { - Role::User => ChatMessage::User { content: text }, - Role::Assistant => ChatMessage::Assistant { - content: Some(text), - tool_calls: Vec::new(), - }, - Role::System => ChatMessage::System { content: text }, - }), + MessageContent::Text(text) => add_message_content_part( + lmstudio::MessagePart::Text { text }, + message.role, + &mut messages, + ), MessageContent::Thinking { .. } => {} MessageContent::RedactedThinking(_) => {} - MessageContent::Image(_) => {} + MessageContent::Image(image) => { + add_message_content_part( + lmstudio::MessagePart::Image { + image_url: lmstudio::ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }, + message.role, + &mut messages, + ); + } MessageContent::ToolUse(tool_use) => { let tool_call = lmstudio::ToolCall { id: tool_use.id.to_string(), @@ -285,23 +296,32 @@ impl LmStudioLanguageModel { } } MessageContent::ToolResult(tool_result) => { - match &tool_result.content { + let content = match &tool_result.content { LanguageModelToolResultContent::Text(text) => { - messages.push(lmstudio::ChatMessage::Tool { - content: text.to_string(), - tool_call_id: tool_result.tool_use_id.to_string(), - }); + vec![lmstudio::MessagePart::Text { + text: text.to_string(), + }] } - LanguageModelToolResultContent::Image(_) => { - // no support for images for now + LanguageModelToolResultContent::Image(image) => { + vec![lmstudio::MessagePart::Image { + image_url: lmstudio::ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + }] } }; + + messages.push(lmstudio::ChatMessage::Tool { + content: content.into(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); } } } } - ChatCompletionRequest { + lmstudio::ChatCompletionRequest { model: self.model.name.clone(), messages, stream: true, @@ -332,10 +352,12 @@ impl LmStudioLanguageModel { fn stream_completion( &self, - request: ChatCompletionRequest, + request: lmstudio::ChatCompletionRequest, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> - { + ) -> BoxFuture< + 'static, + Result>>, + > { let http_client = self.http_client.clone(); let Ok(api_url) = cx.update(|cx| { let settings = &AllLanguageModelSettings::get_global(cx).lmstudio; @@ -345,7 +367,7 @@ impl LmStudioLanguageModel { }; let future = self.request_limiter.stream(async move { - let request = stream_chat_completion(http_client.as_ref(), &api_url, request); + let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request); let response = request.await?; Ok(response) }); @@ -385,7 +407,7 @@ impl LanguageModel for LmStudioLanguageModel { } fn supports_images(&self) -> bool { - false + self.model.supports_images } fn telemetry_id(&self) -> String { @@ -446,7 +468,7 @@ impl LmStudioEventMapper { pub fn map_stream( mut self, - events: Pin>>>, + events: Pin>>>, ) -> impl Stream> { events.flat_map(move |event| { @@ -459,7 +481,7 @@ impl LmStudioEventMapper { pub fn map_event( &mut self, - event: ResponseStreamEvent, + event: lmstudio::ResponseStreamEvent, ) -> Vec> { let Some(choice) = event.choices.into_iter().next() else { return vec![Err(LanguageModelCompletionError::Other(anyhow!( @@ -551,6 +573,40 @@ struct RawToolCall { arguments: String, } +fn add_message_content_part( + new_part: lmstudio::MessagePart, + role: Role, + messages: &mut Vec, +) { + match (role, messages.last_mut()) { + (Role::User, Some(lmstudio::ChatMessage::User { content })) + | ( + Role::Assistant, + Some(lmstudio::ChatMessage::Assistant { + content: Some(content), + .. + }), + ) + | (Role::System, Some(lmstudio::ChatMessage::System { content })) => { + content.push_part(new_part); + } + _ => { + messages.push(match role { + Role::User => lmstudio::ChatMessage::User { + content: lmstudio::MessageContent::from(vec![new_part]), + }, + Role::Assistant => lmstudio::ChatMessage::Assistant { + content: Some(lmstudio::MessageContent::from(vec![new_part])), + tool_calls: Vec::new(), + }, + Role::System => lmstudio::ChatMessage::System { + content: lmstudio::MessageContent::from(vec![new_part]), + }, + }); + } + } +} + struct ConfigurationView { state: gpui::Entity, loading_models_task: Option>, diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index 943f8a2a0d..5c6b610943 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -48,6 +48,7 @@ pub struct Model { pub display_name: Option, pub max_tokens: usize, pub supports_tool_calls: bool, + pub supports_images: bool, } impl Model { @@ -56,12 +57,14 @@ impl Model { display_name: Option<&str>, max_tokens: Option, supports_tool_calls: bool, + supports_images: bool, ) -> Self { Self { name: name.to_owned(), display_name: display_name.map(|s| s.to_owned()), max_tokens: max_tokens.unwrap_or(2048), supports_tool_calls, + supports_images, } } @@ -110,22 +113,78 @@ pub struct FunctionDefinition { pub enum ChatMessage { Assistant { #[serde(default)] - content: Option, + content: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, User { - content: String, + content: MessageContent, }, System { - content: String, + content: MessageContent, }, Tool { - content: String, + content: MessageContent, tool_call_id: String, }, } +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + Plain(String), + Multipart(Vec), +} + +impl MessageContent { + pub fn empty() -> Self { + MessageContent::Multipart(vec![]) + } + + pub fn push_part(&mut self, part: MessagePart) { + match self { + MessageContent::Plain(text) => { + *self = + MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]); + } + MessageContent::Multipart(parts) if parts.is_empty() => match part { + MessagePart::Text { text } => *self = MessageContent::Plain(text), + MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]), + }, + MessageContent::Multipart(parts) => parts.push(part), + } + } +} + +impl From> for MessageContent { + fn from(mut parts: Vec) -> Self { + if let [MessagePart::Text { text }] = parts.as_mut_slice() { + MessageContent::Plain(std::mem::take(text)) + } else { + MessageContent::Multipart(parts) + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessagePart { + Text { + text: String, + }, + #[serde(rename = "image_url")] + Image { + image_url: ImageUrl, + }, +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] +pub struct ImageUrl { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, @@ -210,6 +269,10 @@ impl Capabilities { pub fn supports_tool_calls(&self) -> bool { self.0.iter().any(|cap| cap == "tool_use") } + + pub fn supports_images(&self) -> bool { + self.0.iter().any(|cap| cap == "vision") + } } #[derive(Serialize, Deserialize, Debug)] @@ -393,3 +456,38 @@ pub async fn get_models( serde_json::from_str(&body).context("Unable to parse LM Studio models response")?; Ok(response.data) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_image_message_part_serialization() { + let image_part = MessagePart::Image { + image_url: ImageUrl { + url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(), + detail: None, + }, + }; + + let json = serde_json::to_string(&image_part).unwrap(); + println!("Serialized image part: {}", json); + + // Verify the structure matches what LM Studio expects + let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#; + assert_eq!(json, expected_structure); + } + + #[test] + fn test_text_message_part_serialization() { + let text_part = MessagePart::Text { + text: "Hello, world!".to_string(), + }; + + let json = serde_json::to_string(&text_part).unwrap(); + println!("Serialized text part: {}", json); + + let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#; + assert_eq!(json, expected_structure); + } +}