
Discovered in this issue: #34513. Previously, we were propagating deserialization errors to users when using LMStudio, instead of the actual error message sent from the LMStudio server. This change will help users understand why their request failed while streaming responses. Release Notes: - lmstudio: Display specific backend error messaging on failure rather than generic ones --------- Signed-off-by: Umesh Yadav <git@umesh.dev> Co-authored-by: Peter Tripp <peter@zed.dev>
495 lines
14 KiB
Rust
495 lines
14 KiB
Rust
use anyhow::{Context as _, Result, anyhow};
|
|
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use std::{convert::TryFrom, time::Duration};
|
|
|
|
/// Base URL of the LM Studio `/api/v0` HTTP API on `localhost:1234`.
/// NOTE(review): presumably the default passed as `api_url` to the request
/// helpers in this file — confirm against callers.
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
|
|
|
|
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum Role {
|
|
User,
|
|
Assistant,
|
|
System,
|
|
Tool,
|
|
}
|
|
|
|
impl TryFrom<String> for Role {
|
|
type Error = anyhow::Error;
|
|
|
|
fn try_from(value: String) -> Result<Self> {
|
|
match value.as_str() {
|
|
"user" => Ok(Self::User),
|
|
"assistant" => Ok(Self::Assistant),
|
|
"system" => Ok(Self::System),
|
|
"tool" => Ok(Self::Tool),
|
|
_ => anyhow::bail!("invalid role '{value}'"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<Role> for String {
|
|
fn from(val: Role) -> Self {
|
|
match val {
|
|
Role::User => "user".to_owned(),
|
|
Role::Assistant => "assistant".to_owned(),
|
|
Role::System => "system".to_owned(),
|
|
Role::Tool => "tool".to_owned(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
|
pub struct Model {
|
|
pub name: String,
|
|
pub display_name: Option<String>,
|
|
pub max_tokens: u64,
|
|
pub supports_tool_calls: bool,
|
|
pub supports_images: bool,
|
|
}
|
|
|
|
impl Model {
|
|
pub fn new(
|
|
name: &str,
|
|
display_name: Option<&str>,
|
|
max_tokens: Option<u64>,
|
|
supports_tool_calls: bool,
|
|
supports_images: bool,
|
|
) -> Self {
|
|
Self {
|
|
name: name.to_owned(),
|
|
display_name: display_name.map(|s| s.to_owned()),
|
|
max_tokens: max_tokens.unwrap_or(2048),
|
|
supports_tool_calls,
|
|
supports_images,
|
|
}
|
|
}
|
|
|
|
pub fn id(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
pub fn display_name(&self) -> &str {
|
|
self.display_name.as_ref().unwrap_or(&self.name)
|
|
}
|
|
|
|
pub fn max_token_count(&self) -> u64 {
|
|
self.max_tokens
|
|
}
|
|
|
|
pub fn supports_tool_calls(&self) -> bool {
|
|
self.supports_tool_calls
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
#[serde(untagged)]
|
|
pub enum ToolChoice {
|
|
Auto,
|
|
Required,
|
|
None,
|
|
Other(ToolDefinition),
|
|
}
|
|
|
|
#[derive(Clone, Deserialize, Serialize, Debug)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum ToolDefinition {
|
|
#[allow(dead_code)]
|
|
Function { function: FunctionDefinition },
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct FunctionDefinition {
|
|
pub name: String,
|
|
pub description: Option<String>,
|
|
pub parameters: Option<Value>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
#[serde(tag = "role", rename_all = "lowercase")]
|
|
pub enum ChatMessage {
|
|
Assistant {
|
|
#[serde(default)]
|
|
content: Option<MessageContent>,
|
|
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
|
tool_calls: Vec<ToolCall>,
|
|
},
|
|
User {
|
|
content: MessageContent,
|
|
},
|
|
System {
|
|
content: MessageContent,
|
|
},
|
|
Tool {
|
|
content: MessageContent,
|
|
tool_call_id: String,
|
|
},
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
#[serde(untagged)]
|
|
pub enum MessageContent {
|
|
Plain(String),
|
|
Multipart(Vec<MessagePart>),
|
|
}
|
|
|
|
impl MessageContent {
|
|
pub fn empty() -> Self {
|
|
MessageContent::Multipart(vec![])
|
|
}
|
|
|
|
pub fn push_part(&mut self, part: MessagePart) {
|
|
match self {
|
|
MessageContent::Plain(text) => {
|
|
*self =
|
|
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
|
|
}
|
|
MessageContent::Multipart(parts) if parts.is_empty() => match part {
|
|
MessagePart::Text { text } => *self = MessageContent::Plain(text),
|
|
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
|
|
},
|
|
MessageContent::Multipart(parts) => parts.push(part),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<Vec<MessagePart>> for MessageContent {
|
|
fn from(mut parts: Vec<MessagePart>) -> Self {
|
|
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
|
|
MessageContent::Plain(std::mem::take(text))
|
|
} else {
|
|
MessageContent::Multipart(parts)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum MessagePart {
|
|
Text {
|
|
text: String,
|
|
},
|
|
#[serde(rename = "image_url")]
|
|
Image {
|
|
image_url: ImageUrl,
|
|
},
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
|
pub struct ImageUrl {
|
|
pub url: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub detail: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
pub struct ToolCall {
|
|
pub id: String,
|
|
#[serde(flatten)]
|
|
pub content: ToolCallContent,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
#[serde(tag = "type", rename_all = "lowercase")]
|
|
pub enum ToolCallContent {
|
|
Function { function: FunctionContent },
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
pub struct FunctionContent {
|
|
pub name: String,
|
|
pub arguments: String,
|
|
}
|
|
|
|
#[derive(Serialize, Debug)]
|
|
pub struct ChatCompletionRequest {
|
|
pub model: String,
|
|
pub messages: Vec<ChatMessage>,
|
|
pub stream: bool,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub max_tokens: Option<i32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub stop: Option<Vec<String>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub temperature: Option<f32>,
|
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
|
pub tools: Vec<ToolDefinition>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tool_choice: Option<ToolChoice>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
pub struct ChatResponse {
|
|
pub id: String,
|
|
pub object: String,
|
|
pub created: u64,
|
|
pub model: String,
|
|
pub choices: Vec<ChoiceDelta>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
pub struct ChoiceDelta {
|
|
pub index: u32,
|
|
pub delta: ResponseMessageDelta,
|
|
pub finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
pub struct ToolCallChunk {
|
|
pub index: usize,
|
|
pub id: Option<String>,
|
|
|
|
// There is also an optional `type` field that would determine if a
|
|
// function is there. Sometimes this streams in with the `function` before
|
|
// it streams in the `type`
|
|
pub function: Option<FunctionChunk>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
pub struct FunctionChunk {
|
|
pub name: Option<String>,
|
|
pub arguments: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
pub struct Usage {
|
|
pub prompt_tokens: u64,
|
|
pub completion_tokens: u64,
|
|
pub total_tokens: u64,
|
|
}
|
|
|
|
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
|
|
#[serde(transparent)]
|
|
pub struct Capabilities(Vec<String>);
|
|
|
|
impl Capabilities {
|
|
pub fn supports_tool_calls(&self) -> bool {
|
|
self.0.iter().any(|cap| cap == "tool_use")
|
|
}
|
|
|
|
pub fn supports_images(&self) -> bool {
|
|
self.0.iter().any(|cap| cap == "vision")
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
pub struct LmStudioError {
|
|
pub message: String,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
#[serde(untagged)]
|
|
pub enum ResponseStreamResult {
|
|
Ok(ResponseStreamEvent),
|
|
Err { error: LmStudioError },
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
pub struct ResponseStreamEvent {
|
|
pub created: u32,
|
|
pub model: String,
|
|
pub object: String,
|
|
pub choices: Vec<ChoiceDelta>,
|
|
pub usage: Option<Usage>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct ListModelsResponse {
|
|
pub data: Vec<ModelEntry>,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Deserialize, PartialEq)]
|
|
pub struct ModelEntry {
|
|
pub id: String,
|
|
pub object: String,
|
|
pub r#type: ModelType,
|
|
pub publisher: String,
|
|
pub arch: Option<String>,
|
|
pub compatibility_type: CompatibilityType,
|
|
pub quantization: Option<String>,
|
|
pub state: ModelState,
|
|
pub max_context_length: Option<u64>,
|
|
pub loaded_context_length: Option<u64>,
|
|
#[serde(default)]
|
|
pub capabilities: Capabilities,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum ModelType {
|
|
Llm,
|
|
Embeddings,
|
|
Vlm,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
pub enum ModelState {
|
|
Loaded,
|
|
Loading,
|
|
NotLoaded,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum CompatibilityType {
|
|
Gguf,
|
|
Mlx,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
pub struct ResponseMessageDelta {
|
|
pub role: Option<Role>,
|
|
pub content: Option<String>,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub reasoning_content: Option<String>,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub tool_calls: Option<Vec<ToolCallChunk>>,
|
|
}
|
|
|
|
pub async fn complete(
|
|
client: &dyn HttpClient,
|
|
api_url: &str,
|
|
request: ChatCompletionRequest,
|
|
) -> Result<ChatResponse> {
|
|
let uri = format!("{api_url}/chat/completions");
|
|
let request_builder = HttpRequest::builder()
|
|
.method(Method::POST)
|
|
.uri(uri)
|
|
.header("Content-Type", "application/json");
|
|
|
|
let serialized_request = serde_json::to_string(&request)?;
|
|
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
|
|
|
let mut response = client.send(request).await?;
|
|
if response.status().is_success() {
|
|
let mut body = Vec::new();
|
|
response.body_mut().read_to_end(&mut body).await?;
|
|
let response_message: ChatResponse = serde_json::from_slice(&body)?;
|
|
Ok(response_message)
|
|
} else {
|
|
let mut body = Vec::new();
|
|
response.body_mut().read_to_end(&mut body).await?;
|
|
let body_str = std::str::from_utf8(&body)?;
|
|
anyhow::bail!(
|
|
"Failed to connect to API: {} {}",
|
|
response.status(),
|
|
body_str
|
|
);
|
|
}
|
|
}
|
|
|
|
pub async fn stream_chat_completion(
|
|
client: &dyn HttpClient,
|
|
api_url: &str,
|
|
request: ChatCompletionRequest,
|
|
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
|
let uri = format!("{api_url}/chat/completions");
|
|
let request_builder = http::Request::builder()
|
|
.method(Method::POST)
|
|
.uri(uri)
|
|
.header("Content-Type", "application/json");
|
|
|
|
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
|
let mut response = client.send(request).await?;
|
|
if response.status().is_success() {
|
|
let reader = BufReader::new(response.into_body());
|
|
Ok(reader
|
|
.lines()
|
|
.filter_map(|line| async move {
|
|
match line {
|
|
Ok(line) => {
|
|
let line = line.strip_prefix("data: ")?;
|
|
if line == "[DONE]" {
|
|
None
|
|
} else {
|
|
match serde_json::from_str(line) {
|
|
Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
|
|
Ok(ResponseStreamResult::Err { error, .. }) => {
|
|
Some(Err(anyhow!(error.message)))
|
|
}
|
|
Err(error) => Some(Err(anyhow!(error))),
|
|
}
|
|
}
|
|
}
|
|
Err(error) => Some(Err(anyhow!(error))),
|
|
}
|
|
})
|
|
.boxed())
|
|
} else {
|
|
let mut body = String::new();
|
|
response.body_mut().read_to_string(&mut body).await?;
|
|
anyhow::bail!(
|
|
"Failed to connect to LM Studio API: {} {}",
|
|
response.status(),
|
|
body,
|
|
);
|
|
}
|
|
}
|
|
|
|
pub async fn get_models(
|
|
client: &dyn HttpClient,
|
|
api_url: &str,
|
|
_: Option<Duration>,
|
|
) -> Result<Vec<ModelEntry>> {
|
|
let uri = format!("{api_url}/models");
|
|
let request_builder = HttpRequest::builder()
|
|
.method(Method::GET)
|
|
.uri(uri)
|
|
.header("Accept", "application/json");
|
|
|
|
let request = request_builder.body(AsyncBody::default())?;
|
|
|
|
let mut response = client.send(request).await?;
|
|
|
|
let mut body = String::new();
|
|
response.body_mut().read_to_string(&mut body).await?;
|
|
|
|
anyhow::ensure!(
|
|
response.status().is_success(),
|
|
"Failed to connect to LM Studio API: {} {}",
|
|
response.status(),
|
|
body,
|
|
);
|
|
let response: ListModelsResponse =
|
|
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
|
|
Ok(response.data)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// An image part is serialized with `type` set to `image_url`, and a
    /// `None` detail is left out.
    #[test]
    fn test_image_message_part_serialization() {
        let url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string();
        let part = MessagePart::Image {
            image_url: ImageUrl { url, detail: None },
        };

        let serialized = serde_json::to_string(&part).unwrap();
        println!("Serialized image part: {}", serialized);

        // Verify the structure matches what LM Studio expects
        let expected = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
        assert_eq!(serialized, expected);
    }

    /// A text part is serialized with `type` set to `text`.
    #[test]
    fn test_text_message_part_serialization() {
        let part = MessagePart::Text {
            text: "Hello, world!".to_string(),
        };

        let serialized = serde_json::to_string(&part).unwrap();
        println!("Serialized text part: {}", serialized);

        let expected = r#"{"type":"text","text":"Hello, world!"}"#;
        assert_eq!(serialized, expected);
    }
}
|