language_models: Add images support to LMStudio provider (#32741)
Tested with gemma3:4b LMStudio: beta version 0.3.17 Release Notes: - Add images support to LMStudio provider
This commit is contained in:
parent
6ad9a66cf9
commit
4b88090cca
3 changed files with 190 additions and 34 deletions
|
@ -386,7 +386,9 @@ impl AgentSettingsContent {
|
|||
_ => None,
|
||||
};
|
||||
settings.provider = Some(AgentProviderContentV1::LmStudio {
|
||||
default_model: Some(lmstudio::Model::new(&model, None, None, false)),
|
||||
default_model: Some(lmstudio::Model::new(
|
||||
&model, None, None, false, false,
|
||||
)),
|
||||
api_url,
|
||||
});
|
||||
}
|
||||
|
|
|
@ -14,10 +14,7 @@ use language_model::{
|
|||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
use lmstudio::{
|
||||
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
|
||||
stream_chat_completion,
|
||||
};
|
||||
use lmstudio::{ModelType, get_models};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
@ -49,6 +46,7 @@ pub struct AvailableModel {
|
|||
pub display_name: Option<String>,
|
||||
pub max_tokens: usize,
|
||||
pub supports_tool_calls: bool,
|
||||
pub supports_images: bool,
|
||||
}
|
||||
|
||||
pub struct LmStudioLanguageModelProvider {
|
||||
|
@ -88,6 +86,7 @@ impl State {
|
|||
.loaded_context_length
|
||||
.or_else(|| model.max_context_length),
|
||||
model.capabilities.supports_tool_calls(),
|
||||
model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
|
|||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
supports_tool_calls: model.supports_tool_calls,
|
||||
supports_images: model.supports_images,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
|
|||
}
|
||||
|
||||
impl LmStudioLanguageModel {
|
||||
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
|
||||
fn to_lmstudio_request(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> lmstudio::ChatCompletionRequest {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => messages.push(match message.role {
|
||||
Role::User => ChatMessage::User { content: text },
|
||||
Role::Assistant => ChatMessage::Assistant {
|
||||
content: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => ChatMessage::System { content: text },
|
||||
}),
|
||||
MessageContent::Text(text) => add_message_content_part(
|
||||
lmstudio::MessagePart::Text { text },
|
||||
message.role,
|
||||
&mut messages,
|
||||
),
|
||||
MessageContent::Thinking { .. } => {}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::Image(image) => {
|
||||
add_message_content_part(
|
||||
lmstudio::MessagePart::Image {
|
||||
image_url: lmstudio::ImageUrl {
|
||||
url: image.to_base64_url(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
message.role,
|
||||
&mut messages,
|
||||
);
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = lmstudio::ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
|
@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
|
|||
}
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
match &tool_result.content {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
messages.push(lmstudio::ChatMessage::Tool {
|
||||
content: text.to_string(),
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
});
|
||||
vec![lmstudio::MessagePart::Text {
|
||||
text: text.to_string(),
|
||||
}]
|
||||
}
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
// no support for images for now
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
vec![lmstudio::MessagePart::Image {
|
||||
image_url: lmstudio::ImageUrl {
|
||||
url: image.to_base64_url(),
|
||||
detail: None,
|
||||
},
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(lmstudio::ChatMessage::Tool {
|
||||
content: content.into(),
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ChatCompletionRequest {
|
||||
lmstudio::ChatCompletionRequest {
|
||||
model: self.model.name.clone(),
|
||||
messages,
|
||||
stream: true,
|
||||
|
@ -332,10 +352,12 @@ impl LmStudioLanguageModel {
|
|||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
request: lmstudio::ChatCompletionRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
|
||||
> {
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok(api_url) = cx.update(|cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
|
||||
|
@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
|
|||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
|
||||
let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
});
|
||||
|
@ -385,7 +407,7 @@ impl LanguageModel for LmStudioLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
self.model.supports_images
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
|
@ -446,7 +468,7 @@ impl LmStudioEventMapper {
|
|||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
|
@ -459,7 +481,7 @@ impl LmStudioEventMapper {
|
|||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: ResponseStreamEvent,
|
||||
event: lmstudio::ResponseStreamEvent,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.into_iter().next() else {
|
||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
|
@ -551,6 +573,40 @@ struct RawToolCall {
|
|||
arguments: String,
|
||||
}
|
||||
|
||||
fn add_message_content_part(
|
||||
new_part: lmstudio::MessagePart,
|
||||
role: Role,
|
||||
messages: &mut Vec<lmstudio::ChatMessage>,
|
||||
) {
|
||||
match (role, messages.last_mut()) {
|
||||
(Role::User, Some(lmstudio::ChatMessage::User { content }))
|
||||
| (
|
||||
Role::Assistant,
|
||||
Some(lmstudio::ChatMessage::Assistant {
|
||||
content: Some(content),
|
||||
..
|
||||
}),
|
||||
)
|
||||
| (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
|
||||
content.push_part(new_part);
|
||||
}
|
||||
_ => {
|
||||
messages.push(match role {
|
||||
Role::User => lmstudio::ChatMessage::User {
|
||||
content: lmstudio::MessageContent::from(vec![new_part]),
|
||||
},
|
||||
Role::Assistant => lmstudio::ChatMessage::Assistant {
|
||||
content: Some(lmstudio::MessageContent::from(vec![new_part])),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => lmstudio::ChatMessage::System {
|
||||
content: lmstudio::MessageContent::from(vec![new_part]),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: gpui::Entity<State>,
|
||||
loading_models_task: Option<Task<()>>,
|
||||
|
|
|
@ -48,6 +48,7 @@ pub struct Model {
|
|||
pub display_name: Option<String>,
|
||||
pub max_tokens: usize,
|
||||
pub supports_tool_calls: bool,
|
||||
pub supports_images: bool,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
|
@ -56,12 +57,14 @@ impl Model {
|
|||
display_name: Option<&str>,
|
||||
max_tokens: Option<usize>,
|
||||
supports_tool_calls: bool,
|
||||
supports_images: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_owned(),
|
||||
display_name: display_name.map(|s| s.to_owned()),
|
||||
max_tokens: max_tokens.unwrap_or(2048),
|
||||
supports_tool_calls,
|
||||
supports_images,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,22 +113,78 @@ pub struct FunctionDefinition {
|
|||
pub enum ChatMessage {
|
||||
Assistant {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
content: Option<MessageContent>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
content: MessageContent,
|
||||
},
|
||||
Tool {
|
||||
content: String,
|
||||
content: MessageContent,
|
||||
tool_call_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
Plain(String),
|
||||
Multipart(Vec<MessagePart>),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn empty() -> Self {
|
||||
MessageContent::Multipart(vec![])
|
||||
}
|
||||
|
||||
pub fn push_part(&mut self, part: MessagePart) {
|
||||
match self {
|
||||
MessageContent::Plain(text) => {
|
||||
*self =
|
||||
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
|
||||
}
|
||||
MessageContent::Multipart(parts) if parts.is_empty() => match part {
|
||||
MessagePart::Text { text } => *self = MessageContent::Plain(text),
|
||||
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
|
||||
},
|
||||
MessageContent::Multipart(parts) => parts.push(part),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<MessagePart>> for MessageContent {
|
||||
fn from(mut parts: Vec<MessagePart>) -> Self {
|
||||
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
|
||||
MessageContent::Plain(std::mem::take(text))
|
||||
} else {
|
||||
MessageContent::Multipart(parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessagePart {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
#[serde(rename = "image_url")]
|
||||
Image {
|
||||
image_url: ImageUrl,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
|
@ -210,6 +269,10 @@ impl Capabilities {
|
|||
pub fn supports_tool_calls(&self) -> bool {
|
||||
self.0.iter().any(|cap| cap == "tool_use")
|
||||
}
|
||||
|
||||
pub fn supports_images(&self) -> bool {
|
||||
self.0.iter().any(|cap| cap == "vision")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
|
@ -393,3 +456,38 @@ pub async fn get_models(
|
|||
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
|
||||
Ok(response.data)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_image_message_part_serialization() {
|
||||
let image_part = MessagePart::Image {
|
||||
image_url: ImageUrl {
|
||||
url: "".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&image_part).unwrap();
|
||||
println!("Serialized image part: {}", json);
|
||||
|
||||
// Verify the structure matches what LM Studio expects
|
||||
let expected_structure = r#"{"type":"image_url","image_url":{"url":""}}"#;
|
||||
assert_eq!(json, expected_structure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_text_message_part_serialization() {
|
||||
let text_part = MessagePart::Text {
|
||||
text: "Hello, world!".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&text_part).unwrap();
|
||||
println!("Serialized text part: {}", json);
|
||||
|
||||
let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
|
||||
assert_eq!(json, expected_structure);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue