use anyhow::{Context as _, Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::{convert::TryFrom, time::Duration}; pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0"; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { User, Assistant, System, Tool, } impl TryFrom for Role { type Error = anyhow::Error; fn try_from(value: String) -> Result { match value.as_str() { "user" => Ok(Self::User), "assistant" => Ok(Self::Assistant), "system" => Ok(Self::System), "tool" => Ok(Self::Tool), _ => anyhow::bail!("invalid role '{value}'"), } } } impl From for String { fn from(val: Role) -> Self { match val { Role::User => "user".to_owned(), Role::Assistant => "assistant".to_owned(), Role::System => "system".to_owned(), Role::Tool => "tool".to_owned(), } } } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub struct Model { pub name: String, pub display_name: Option, pub max_tokens: u64, pub supports_tool_calls: bool, pub supports_images: bool, } impl Model { pub fn new( name: &str, display_name: Option<&str>, max_tokens: Option, supports_tool_calls: bool, supports_images: bool, ) -> Self { Self { name: name.to_owned(), display_name: display_name.map(|s| s.to_owned()), max_tokens: max_tokens.unwrap_or(2048), supports_tool_calls, supports_images, } } pub fn id(&self) -> &str { &self.name } pub fn display_name(&self) -> &str { self.display_name.as_ref().unwrap_or(&self.name) } pub fn max_token_count(&self) -> u64 { self.max_tokens } pub fn supports_tool_calls(&self) -> bool { self.supports_tool_calls } } #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum ToolChoice { Auto, Required, None, Other(ToolDefinition), } #[derive(Clone, Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { #[allow(dead_code)] Function { function: FunctionDefinition }, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct FunctionDefinition { pub name: String, pub description: Option, pub parameters: Option, } #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "role", rename_all = "lowercase")] pub enum ChatMessage { Assistant { #[serde(default)] content: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, User { content: MessageContent, }, System { content: MessageContent, }, Tool { content: MessageContent, tool_call_id: String, }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(untagged)] pub enum MessageContent { Plain(String), Multipart(Vec), } impl MessageContent { pub fn empty() -> Self { MessageContent::Multipart(vec![]) } pub fn push_part(&mut self, part: MessagePart) { match self { MessageContent::Plain(text) => { *self = MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]); } MessageContent::Multipart(parts) if parts.is_empty() => match part { MessagePart::Text { text } => *self = MessageContent::Plain(text), MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]), }, MessageContent::Multipart(parts) => parts.push(part), } } } impl From> for MessageContent { fn from(mut parts: Vec) -> Self { if let [MessagePart::Text { text }] = parts.as_mut_slice() { MessageContent::Plain(std::mem::take(text)) } else { MessageContent::Multipart(parts) } } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "type", rename_all = "snake_case")] pub enum MessagePart { Text { text: String, }, #[serde(rename = "image_url")] Image { image_url: ImageUrl, }, } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] pub struct ImageUrl { pub url: String, #[serde(skip_serializing_if = "Option::is_none")] pub detail: Option, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, #[serde(flatten)] pub content: ToolCallContent, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "type", rename_all = "lowercase")] pub enum ToolCallContent { Function { function: FunctionContent }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct FunctionContent { pub name: String, pub arguments: String, } #[derive(Serialize, Debug)] pub struct ChatCompletionRequest { pub model: String, pub messages: Vec, pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] pub stop: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Vec::is_empty")] pub tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct ChatResponse { pub id: String, pub object: String, pub created: u64, pub model: String, pub choices: Vec, } #[derive(Serialize, Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, pub delta: ResponseMessageDelta, pub finish_reason: Option, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCallChunk { pub index: usize, pub id: Option, // There is also an optional `type` field that would determine if a // function is there. Sometimes this streams in with the `function` before // it streams in the `type` pub function: Option, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct FunctionChunk { pub name: Option, pub arguments: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u64, pub completion_tokens: u64, pub total_tokens: u64, } #[derive(Debug, Default, Clone, Deserialize, PartialEq)] #[serde(transparent)] pub struct Capabilities(Vec); impl Capabilities { pub fn supports_tool_calls(&self) -> bool { self.0.iter().any(|cap| cap == "tool_use") } pub fn supports_images(&self) -> bool { self.0.iter().any(|cap| cap == "vision") } } #[derive(Serialize, Deserialize, Debug)] pub struct LmStudioError { pub message: String, } #[derive(Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum ResponseStreamResult { Ok(ResponseStreamEvent), Err { error: LmStudioError }, } #[derive(Serialize, Deserialize, Debug)] pub struct ResponseStreamEvent { pub created: u32, pub model: String, pub object: String, pub choices: Vec, pub usage: Option, } #[derive(Deserialize)] pub struct ListModelsResponse { pub data: Vec, } #[derive(Clone, Debug, Deserialize, PartialEq)] pub struct ModelEntry { pub id: String, pub object: String, pub r#type: ModelType, pub publisher: String, pub arch: Option, pub compatibility_type: CompatibilityType, pub quantization: Option, pub state: ModelState, pub max_context_length: Option, pub loaded_context_length: Option, #[serde(default)] pub capabilities: Capabilities, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ModelType { Llm, Embeddings, Vlm, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "kebab-case")] pub enum ModelState { Loaded, Loading, NotLoaded, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum CompatibilityType { Gguf, Mlx, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ResponseMessageDelta { pub role: Option, pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning_content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, } pub async fn complete( client: &dyn HttpClient, api_url: &str, request: ChatCompletionRequest, ) -> Result { let uri = format!("{api_url}/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); let serialized_request = serde_json::to_string(&request)?; let request = request_builder.body(AsyncBody::from(serialized_request))?; let mut response = client.send(request).await?; if response.status().is_success() { let mut body = Vec::new(); response.body_mut().read_to_end(&mut body).await?; let response_message: ChatResponse = serde_json::from_slice(&body)?; Ok(response_message) } else { let mut body = Vec::new(); response.body_mut().read_to_end(&mut body).await?; let body_str = std::str::from_utf8(&body)?; anyhow::bail!( "Failed to connect to API: {} {}", response.status(), body_str ); } } pub async fn stream_chat_completion( client: &dyn HttpClient, api_url: &str, request: ChatCompletionRequest, ) -> Result>> { let uri = format!("{api_url}/chat/completions"); let request_builder = http::Request::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); Ok(reader .lines() .filter_map(|line| async move { match line { Ok(line) => { let line = line.strip_prefix("data: ")?; if line == "[DONE]" { None } else { match serde_json::from_str(line) { Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)), Ok(ResponseStreamResult::Err { error, .. }) => { Some(Err(anyhow!(error.message))) } Err(error) => Some(Err(anyhow!(error))), } } } Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; anyhow::bail!( "Failed to connect to LM Studio API: {} {}", response.status(), body, ); } } pub async fn get_models( client: &dyn HttpClient, api_url: &str, _: Option, ) -> Result> { let uri = format!("{api_url}/models"); let request_builder = HttpRequest::builder() .method(Method::GET) .uri(uri) .header("Accept", "application/json"); let request = request_builder.body(AsyncBody::default())?; let mut response = client.send(request).await?; let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; anyhow::ensure!( response.status().is_success(), "Failed to connect to LM Studio API: {} {}", response.status(), body, ); let response: ListModelsResponse = serde_json::from_str(&body).context("Unable to parse LM Studio models response")?; Ok(response.data) } #[cfg(test)] mod tests { use super::*; #[test] fn test_image_message_part_serialization() { let image_part = MessagePart::Image { image_url: ImageUrl { url: "".to_string(), detail: None, }, }; let json = serde_json::to_string(&image_part).unwrap(); println!("Serialized image part: {}", json); // Verify the structure matches what LM Studio expects let expected_structure = r#"{"type":"image_url","image_url":{"url":""}}"#; assert_eq!(json, expected_structure); } #[test] fn test_text_message_part_serialization() { let text_part = MessagePart::Text { text: "Hello, world!".to_string(), }; let json = serde_json::to_string(&text_part).unwrap(); println!("Serialized text part: {}", json); let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#; assert_eq!(json, expected_structure); } }