language_models: Add images support to LMStudio provider (#32741)
Tested with gemma3:4b LMStudio: beta version 0.3.17 Release Notes: - Add images support to LMStudio provider
This commit is contained in:
parent
6ad9a66cf9
commit
4b88090cca
3 changed files with 190 additions and 34 deletions
|
@ -386,7 +386,9 @@ impl AgentSettingsContent {
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
settings.provider = Some(AgentProviderContentV1::LmStudio {
|
settings.provider = Some(AgentProviderContentV1::LmStudio {
|
||||||
default_model: Some(lmstudio::Model::new(&model, None, None, false)),
|
default_model: Some(lmstudio::Model::new(
|
||||||
|
&model, None, None, false, false,
|
||||||
|
)),
|
||||||
api_url,
|
api_url,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,10 +14,7 @@ use language_model::{
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
LanguageModelRequest, RateLimiter, Role,
|
LanguageModelRequest, RateLimiter, Role,
|
||||||
};
|
};
|
||||||
use lmstudio::{
|
use lmstudio::{ModelType, get_models};
|
||||||
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
|
|
||||||
stream_chat_completion,
|
|
||||||
};
|
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
|
@ -49,6 +46,7 @@ pub struct AvailableModel {
|
||||||
pub display_name: Option<String>,
|
pub display_name: Option<String>,
|
||||||
pub max_tokens: usize,
|
pub max_tokens: usize,
|
||||||
pub supports_tool_calls: bool,
|
pub supports_tool_calls: bool,
|
||||||
|
pub supports_images: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LmStudioLanguageModelProvider {
|
pub struct LmStudioLanguageModelProvider {
|
||||||
|
@ -88,6 +86,7 @@ impl State {
|
||||||
.loaded_context_length
|
.loaded_context_length
|
||||||
.or_else(|| model.max_context_length),
|
.or_else(|| model.max_context_length),
|
||||||
model.capabilities.supports_tool_calls(),
|
model.capabilities.supports_tool_calls(),
|
||||||
|
model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
|
||||||
display_name: model.display_name.clone(),
|
display_name: model.display_name.clone(),
|
||||||
max_tokens: model.max_tokens,
|
max_tokens: model.max_tokens,
|
||||||
supports_tool_calls: model.supports_tool_calls,
|
supports_tool_calls: model.supports_tool_calls,
|
||||||
|
supports_images: model.supports_images,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LmStudioLanguageModel {
|
impl LmStudioLanguageModel {
|
||||||
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
|
fn to_lmstudio_request(
|
||||||
|
&self,
|
||||||
|
request: LanguageModelRequest,
|
||||||
|
) -> lmstudio::ChatCompletionRequest {
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
for message in request.messages {
|
for message in request.messages {
|
||||||
for content in message.content {
|
for content in message.content {
|
||||||
match content {
|
match content {
|
||||||
MessageContent::Text(text) => messages.push(match message.role {
|
MessageContent::Text(text) => add_message_content_part(
|
||||||
Role::User => ChatMessage::User { content: text },
|
lmstudio::MessagePart::Text { text },
|
||||||
Role::Assistant => ChatMessage::Assistant {
|
message.role,
|
||||||
content: Some(text),
|
&mut messages,
|
||||||
tool_calls: Vec::new(),
|
),
|
||||||
},
|
|
||||||
Role::System => ChatMessage::System { content: text },
|
|
||||||
}),
|
|
||||||
MessageContent::Thinking { .. } => {}
|
MessageContent::Thinking { .. } => {}
|
||||||
MessageContent::RedactedThinking(_) => {}
|
MessageContent::RedactedThinking(_) => {}
|
||||||
MessageContent::Image(_) => {}
|
MessageContent::Image(image) => {
|
||||||
|
add_message_content_part(
|
||||||
|
lmstudio::MessagePart::Image {
|
||||||
|
image_url: lmstudio::ImageUrl {
|
||||||
|
url: image.to_base64_url(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
message.role,
|
||||||
|
&mut messages,
|
||||||
|
);
|
||||||
|
}
|
||||||
MessageContent::ToolUse(tool_use) => {
|
MessageContent::ToolUse(tool_use) => {
|
||||||
let tool_call = lmstudio::ToolCall {
|
let tool_call = lmstudio::ToolCall {
|
||||||
id: tool_use.id.to_string(),
|
id: tool_use.id.to_string(),
|
||||||
|
@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MessageContent::ToolResult(tool_result) => {
|
MessageContent::ToolResult(tool_result) => {
|
||||||
match &tool_result.content {
|
let content = match &tool_result.content {
|
||||||
LanguageModelToolResultContent::Text(text) => {
|
LanguageModelToolResultContent::Text(text) => {
|
||||||
messages.push(lmstudio::ChatMessage::Tool {
|
vec![lmstudio::MessagePart::Text {
|
||||||
content: text.to_string(),
|
text: text.to_string(),
|
||||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
}]
|
||||||
});
|
|
||||||
}
|
}
|
||||||
LanguageModelToolResultContent::Image(_) => {
|
LanguageModelToolResultContent::Image(image) => {
|
||||||
// no support for images for now
|
vec![lmstudio::MessagePart::Image {
|
||||||
|
image_url: lmstudio::ImageUrl {
|
||||||
|
url: image.to_base64_url(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
}]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
messages.push(lmstudio::ChatMessage::Tool {
|
||||||
|
content: content.into(),
|
||||||
|
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatCompletionRequest {
|
lmstudio::ChatCompletionRequest {
|
||||||
model: self.model.name.clone(),
|
model: self.model.name.clone(),
|
||||||
messages,
|
messages,
|
||||||
stream: true,
|
stream: true,
|
||||||
|
@ -332,10 +352,12 @@ impl LmStudioLanguageModel {
|
||||||
|
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: ChatCompletionRequest,
|
request: lmstudio::ChatCompletionRequest,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
) -> BoxFuture<
|
||||||
{
|
'static,
|
||||||
|
Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
|
||||||
|
> {
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let Ok(api_url) = cx.update(|cx| {
|
let Ok(api_url) = cx.update(|cx| {
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
|
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
|
||||||
|
@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
|
||||||
};
|
};
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
|
let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(response)
|
Ok(response)
|
||||||
});
|
});
|
||||||
|
@ -385,7 +407,7 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_images(&self) -> bool {
|
fn supports_images(&self) -> bool {
|
||||||
false
|
self.model.supports_images
|
||||||
}
|
}
|
||||||
|
|
||||||
fn telemetry_id(&self) -> String {
|
fn telemetry_id(&self) -> String {
|
||||||
|
@ -446,7 +468,7 @@ impl LmStudioEventMapper {
|
||||||
|
|
||||||
pub fn map_stream(
|
pub fn map_stream(
|
||||||
mut self,
|
mut self,
|
||||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
|
||||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||||
{
|
{
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
|
@ -459,7 +481,7 @@ impl LmStudioEventMapper {
|
||||||
|
|
||||||
pub fn map_event(
|
pub fn map_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
event: ResponseStreamEvent,
|
event: lmstudio::ResponseStreamEvent,
|
||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||||
let Some(choice) = event.choices.into_iter().next() else {
|
let Some(choice) = event.choices.into_iter().next() else {
|
||||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||||
|
@ -551,6 +573,40 @@ struct RawToolCall {
|
||||||
arguments: String,
|
arguments: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_message_content_part(
|
||||||
|
new_part: lmstudio::MessagePart,
|
||||||
|
role: Role,
|
||||||
|
messages: &mut Vec<lmstudio::ChatMessage>,
|
||||||
|
) {
|
||||||
|
match (role, messages.last_mut()) {
|
||||||
|
(Role::User, Some(lmstudio::ChatMessage::User { content }))
|
||||||
|
| (
|
||||||
|
Role::Assistant,
|
||||||
|
Some(lmstudio::ChatMessage::Assistant {
|
||||||
|
content: Some(content),
|
||||||
|
..
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
| (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
|
||||||
|
content.push_part(new_part);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
messages.push(match role {
|
||||||
|
Role::User => lmstudio::ChatMessage::User {
|
||||||
|
content: lmstudio::MessageContent::from(vec![new_part]),
|
||||||
|
},
|
||||||
|
Role::Assistant => lmstudio::ChatMessage::Assistant {
|
||||||
|
content: Some(lmstudio::MessageContent::from(vec![new_part])),
|
||||||
|
tool_calls: Vec::new(),
|
||||||
|
},
|
||||||
|
Role::System => lmstudio::ChatMessage::System {
|
||||||
|
content: lmstudio::MessageContent::from(vec![new_part]),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ConfigurationView {
|
struct ConfigurationView {
|
||||||
state: gpui::Entity<State>,
|
state: gpui::Entity<State>,
|
||||||
loading_models_task: Option<Task<()>>,
|
loading_models_task: Option<Task<()>>,
|
||||||
|
|
|
@ -48,6 +48,7 @@ pub struct Model {
|
||||||
pub display_name: Option<String>,
|
pub display_name: Option<String>,
|
||||||
pub max_tokens: usize,
|
pub max_tokens: usize,
|
||||||
pub supports_tool_calls: bool,
|
pub supports_tool_calls: bool,
|
||||||
|
pub supports_images: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
|
@ -56,12 +57,14 @@ impl Model {
|
||||||
display_name: Option<&str>,
|
display_name: Option<&str>,
|
||||||
max_tokens: Option<usize>,
|
max_tokens: Option<usize>,
|
||||||
supports_tool_calls: bool,
|
supports_tool_calls: bool,
|
||||||
|
supports_images: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.to_owned(),
|
name: name.to_owned(),
|
||||||
display_name: display_name.map(|s| s.to_owned()),
|
display_name: display_name.map(|s| s.to_owned()),
|
||||||
max_tokens: max_tokens.unwrap_or(2048),
|
max_tokens: max_tokens.unwrap_or(2048),
|
||||||
supports_tool_calls,
|
supports_tool_calls,
|
||||||
|
supports_images,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,22 +113,78 @@ pub struct FunctionDefinition {
|
||||||
pub enum ChatMessage {
|
pub enum ChatMessage {
|
||||||
Assistant {
|
Assistant {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<MessageContent>,
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
},
|
},
|
||||||
User {
|
User {
|
||||||
content: String,
|
content: MessageContent,
|
||||||
},
|
},
|
||||||
System {
|
System {
|
||||||
content: String,
|
content: MessageContent,
|
||||||
},
|
},
|
||||||
Tool {
|
Tool {
|
||||||
content: String,
|
content: MessageContent,
|
||||||
tool_call_id: String,
|
tool_call_id: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum MessageContent {
|
||||||
|
Plain(String),
|
||||||
|
Multipart(Vec<MessagePart>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MessageContent {
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
MessageContent::Multipart(vec![])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push_part(&mut self, part: MessagePart) {
|
||||||
|
match self {
|
||||||
|
MessageContent::Plain(text) => {
|
||||||
|
*self =
|
||||||
|
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
|
||||||
|
}
|
||||||
|
MessageContent::Multipart(parts) if parts.is_empty() => match part {
|
||||||
|
MessagePart::Text { text } => *self = MessageContent::Plain(text),
|
||||||
|
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
|
||||||
|
},
|
||||||
|
MessageContent::Multipart(parts) => parts.push(part),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<MessagePart>> for MessageContent {
|
||||||
|
fn from(mut parts: Vec<MessagePart>) -> Self {
|
||||||
|
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
|
||||||
|
MessageContent::Plain(std::mem::take(text))
|
||||||
|
} else {
|
||||||
|
MessageContent::Multipart(parts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum MessagePart {
|
||||||
|
Text {
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
#[serde(rename = "image_url")]
|
||||||
|
Image {
|
||||||
|
image_url: ImageUrl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||||
|
pub struct ImageUrl {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub detail: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
@ -210,6 +269,10 @@ impl Capabilities {
|
||||||
pub fn supports_tool_calls(&self) -> bool {
|
pub fn supports_tool_calls(&self) -> bool {
|
||||||
self.0.iter().any(|cap| cap == "tool_use")
|
self.0.iter().any(|cap| cap == "tool_use")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn supports_images(&self) -> bool {
|
||||||
|
self.0.iter().any(|cap| cap == "vision")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
@ -393,3 +456,38 @@ pub async fn get_models(
|
||||||
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
|
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
|
||||||
Ok(response.data)
|
Ok(response.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_image_message_part_serialization() {
|
||||||
|
let image_part = MessagePart::Image {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&image_part).unwrap();
|
||||||
|
println!("Serialized image part: {}", json);
|
||||||
|
|
||||||
|
// Verify the structure matches what LM Studio expects
|
||||||
|
let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
|
||||||
|
assert_eq!(json, expected_structure);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_text_message_part_serialization() {
|
||||||
|
let text_part = MessagePart::Text {
|
||||||
|
text: "Hello, world!".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&text_part).unwrap();
|
||||||
|
println!("Serialized text part: {}", json);
|
||||||
|
|
||||||
|
let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
|
||||||
|
assert_eq!(json, expected_structure);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue