language_models: Add images support to LMStudio provider (#32741)

Tested with gemma3:4b
LMStudio: beta version 0.3.17

Release Notes:

- Add images support to LMStudio provider
This commit is contained in:
Umesh Yadav 2025-06-17 15:44:44 +05:30 committed by GitHub
parent 6ad9a66cf9
commit 4b88090cca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 190 additions and 34 deletions

View file

@ -386,7 +386,9 @@ impl AgentSettingsContent {
_ => None, _ => None,
}; };
settings.provider = Some(AgentProviderContentV1::LmStudio { settings.provider = Some(AgentProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(&model, None, None, false)), default_model: Some(lmstudio::Model::new(
&model, None, None, false, false,
)),
api_url, api_url,
}); });
} }

View file

@ -14,10 +14,7 @@ use language_model::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role, LanguageModelRequest, RateLimiter, Role,
}; };
use lmstudio::{ use lmstudio::{ModelType, get_models};
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
stream_chat_completion,
};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
@ -49,6 +46,7 @@ pub struct AvailableModel {
pub display_name: Option<String>, pub display_name: Option<String>,
pub max_tokens: usize, pub max_tokens: usize,
pub supports_tool_calls: bool, pub supports_tool_calls: bool,
pub supports_images: bool,
} }
pub struct LmStudioLanguageModelProvider { pub struct LmStudioLanguageModelProvider {
@ -88,6 +86,7 @@ impl State {
.loaded_context_length .loaded_context_length
.or_else(|| model.max_context_length), .or_else(|| model.max_context_length),
model.capabilities.supports_tool_calls(), model.capabilities.supports_tool_calls(),
model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
) )
}) })
.collect(); .collect();
@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
display_name: model.display_name.clone(), display_name: model.display_name.clone(),
max_tokens: model.max_tokens, max_tokens: model.max_tokens,
supports_tool_calls: model.supports_tool_calls, supports_tool_calls: model.supports_tool_calls,
supports_images: model.supports_images,
}, },
); );
} }
@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
} }
impl LmStudioLanguageModel { impl LmStudioLanguageModel {
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest { fn to_lmstudio_request(
&self,
request: LanguageModelRequest,
) -> lmstudio::ChatCompletionRequest {
let mut messages = Vec::new(); let mut messages = Vec::new();
for message in request.messages { for message in request.messages {
for content in message.content { for content in message.content {
match content { match content {
MessageContent::Text(text) => messages.push(match message.role { MessageContent::Text(text) => add_message_content_part(
Role::User => ChatMessage::User { content: text }, lmstudio::MessagePart::Text { text },
Role::Assistant => ChatMessage::Assistant { message.role,
content: Some(text), &mut messages,
tool_calls: Vec::new(), ),
},
Role::System => ChatMessage::System { content: text },
}),
MessageContent::Thinking { .. } => {} MessageContent::Thinking { .. } => {}
MessageContent::RedactedThinking(_) => {} MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {} MessageContent::Image(image) => {
add_message_content_part(
lmstudio::MessagePart::Image {
image_url: lmstudio::ImageUrl {
url: image.to_base64_url(),
detail: None,
},
},
message.role,
&mut messages,
);
}
MessageContent::ToolUse(tool_use) => { MessageContent::ToolUse(tool_use) => {
let tool_call = lmstudio::ToolCall { let tool_call = lmstudio::ToolCall {
id: tool_use.id.to_string(), id: tool_use.id.to_string(),
@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
} }
} }
MessageContent::ToolResult(tool_result) => { MessageContent::ToolResult(tool_result) => {
match &tool_result.content { let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => { LanguageModelToolResultContent::Text(text) => {
messages.push(lmstudio::ChatMessage::Tool { vec![lmstudio::MessagePart::Text {
content: text.to_string(), text: text.to_string(),
tool_call_id: tool_result.tool_use_id.to_string(), }]
});
} }
LanguageModelToolResultContent::Image(_) => { LanguageModelToolResultContent::Image(image) => {
// no support for images for now vec![lmstudio::MessagePart::Image {
image_url: lmstudio::ImageUrl {
url: image.to_base64_url(),
detail: None,
},
}]
} }
}; };
messages.push(lmstudio::ChatMessage::Tool {
content: content.into(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
} }
} }
} }
} }
ChatCompletionRequest { lmstudio::ChatCompletionRequest {
model: self.model.name.clone(), model: self.model.name.clone(),
messages, messages,
stream: true, stream: true,
@ -332,10 +352,12 @@ impl LmStudioLanguageModel {
fn stream_completion( fn stream_completion(
&self, &self,
request: ChatCompletionRequest, request: lmstudio::ChatCompletionRequest,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>> ) -> BoxFuture<
{ 'static,
Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
> {
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| { let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio; let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
}; };
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let request = stream_chat_completion(http_client.as_ref(), &api_url, request); let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
let response = request.await?; let response = request.await?;
Ok(response) Ok(response)
}); });
@ -385,7 +407,7 @@ impl LanguageModel for LmStudioLanguageModel {
} }
fn supports_images(&self) -> bool { fn supports_images(&self) -> bool {
false self.model.supports_images
} }
fn telemetry_id(&self) -> String { fn telemetry_id(&self) -> String {
@ -446,7 +468,7 @@ impl LmStudioEventMapper {
pub fn map_stream( pub fn map_stream(
mut self, mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>, events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{ {
events.flat_map(move |event| { events.flat_map(move |event| {
@ -459,7 +481,7 @@ impl LmStudioEventMapper {
pub fn map_event( pub fn map_event(
&mut self, &mut self,
event: ResponseStreamEvent, event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else { let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!( return vec![Err(LanguageModelCompletionError::Other(anyhow!(
@ -551,6 +573,40 @@ struct RawToolCall {
arguments: String, arguments: String,
} }
fn add_message_content_part(
new_part: lmstudio::MessagePart,
role: Role,
messages: &mut Vec<lmstudio::ChatMessage>,
) {
match (role, messages.last_mut()) {
(Role::User, Some(lmstudio::ChatMessage::User { content }))
| (
Role::Assistant,
Some(lmstudio::ChatMessage::Assistant {
content: Some(content),
..
}),
)
| (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
content.push_part(new_part);
}
_ => {
messages.push(match role {
Role::User => lmstudio::ChatMessage::User {
content: lmstudio::MessageContent::from(vec![new_part]),
},
Role::Assistant => lmstudio::ChatMessage::Assistant {
content: Some(lmstudio::MessageContent::from(vec![new_part])),
tool_calls: Vec::new(),
},
Role::System => lmstudio::ChatMessage::System {
content: lmstudio::MessageContent::from(vec![new_part]),
},
});
}
}
}
struct ConfigurationView { struct ConfigurationView {
state: gpui::Entity<State>, state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>, loading_models_task: Option<Task<()>>,

View file

@ -48,6 +48,7 @@ pub struct Model {
pub display_name: Option<String>, pub display_name: Option<String>,
pub max_tokens: usize, pub max_tokens: usize,
pub supports_tool_calls: bool, pub supports_tool_calls: bool,
pub supports_images: bool,
} }
impl Model { impl Model {
@ -56,12 +57,14 @@ impl Model {
display_name: Option<&str>, display_name: Option<&str>,
max_tokens: Option<usize>, max_tokens: Option<usize>,
supports_tool_calls: bool, supports_tool_calls: bool,
supports_images: bool,
) -> Self { ) -> Self {
Self { Self {
name: name.to_owned(), name: name.to_owned(),
display_name: display_name.map(|s| s.to_owned()), display_name: display_name.map(|s| s.to_owned()),
max_tokens: max_tokens.unwrap_or(2048), max_tokens: max_tokens.unwrap_or(2048),
supports_tool_calls, supports_tool_calls,
supports_images,
} }
} }
@ -110,22 +113,78 @@ pub struct FunctionDefinition {
pub enum ChatMessage { pub enum ChatMessage {
Assistant { Assistant {
#[serde(default)] #[serde(default)]
content: Option<String>, content: Option<MessageContent>,
#[serde(default, skip_serializing_if = "Vec::is_empty")] #[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
}, },
User { User {
content: String, content: MessageContent,
}, },
System { System {
content: String, content: MessageContent,
}, },
Tool { Tool {
content: String, content: MessageContent,
tool_call_id: String, tool_call_id: String,
}, },
} }
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
Plain(String),
Multipart(Vec<MessagePart>),
}
impl MessageContent {
pub fn empty() -> Self {
MessageContent::Multipart(vec![])
}
pub fn push_part(&mut self, part: MessagePart) {
match self {
MessageContent::Plain(text) => {
*self =
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
}
MessageContent::Multipart(parts) if parts.is_empty() => match part {
MessagePart::Text { text } => *self = MessageContent::Plain(text),
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
},
MessageContent::Multipart(parts) => parts.push(part),
}
}
}
impl From<Vec<MessagePart>> for MessageContent {
fn from(mut parts: Vec<MessagePart>) -> Self {
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
MessageContent::Plain(std::mem::take(text))
} else {
MessageContent::Multipart(parts)
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
Text {
text: String,
},
#[serde(rename = "image_url")]
Image {
image_url: ImageUrl,
},
}
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall { pub struct ToolCall {
pub id: String, pub id: String,
@ -210,6 +269,10 @@ impl Capabilities {
pub fn supports_tool_calls(&self) -> bool { pub fn supports_tool_calls(&self) -> bool {
self.0.iter().any(|cap| cap == "tool_use") self.0.iter().any(|cap| cap == "tool_use")
} }
pub fn supports_images(&self) -> bool {
self.0.iter().any(|cap| cap == "vision")
}
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -393,3 +456,38 @@ pub async fn get_models(
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?; serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
Ok(response.data) Ok(response.data)
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_image_message_part_serialization() {
let image_part = MessagePart::Image {
image_url: ImageUrl {
url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
detail: None,
},
};
let json = serde_json::to_string(&image_part).unwrap();
println!("Serialized image part: {}", json);
// Verify the structure matches what LM Studio expects
let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
assert_eq!(json, expected_structure);
}
#[test]
fn test_text_message_part_serialization() {
let text_part = MessagePart::Text {
text: "Hello, world!".to_string(),
};
let json = serde_json::to_string(&text_part).unwrap();
println!("Serialized text part: {}", json);
let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
assert_eq!(json, expected_structure);
}
}