ZIm/crates/semantic_index/src/embedding/ollama.rs

use anyhow::{Context as _, Result};
use futures::{AsyncReadExt as _, FutureExt, future::BoxFuture};
use http_client::HttpClient;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
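
/// The embedding models this provider can ask a local Ollama instance to run.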
pub enum OllamaEmbeddingModel {
    NomicEmbedText,
    MxbaiEmbedLarge,
}
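
/// An [`EmbeddingProvider`] that computes embeddings via the `/api/embeddings`
/// endpoint of an Ollama server listening on `localhost:11434`.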
pub struct OllamaEmbeddingProvider {
    client: Arc<dyn HttpClient>,
    model: OllamaEmbeddingModel,
}

#[derive(Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}
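
// Wire shapes, for reference (one request per input text; the vector values
// are illustrative):
//   request:  {"model": "nomic-embed-text", "prompt": "..."}
//   response: {"embedding": [0.0158, -0.0065, ...]}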
impl OllamaEmbeddingProvider {
    pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
        Self { client, model }
    }
}
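
// `embed` sends one request per text and awaits them concurrently with
// `try_join_all`; `/api/embeddings` accepts a single prompt at a time, so any
// batching is client-side, bounded by `batch_size`.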
impl EmbeddingProvider for OllamaEmbeddingProvider {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        // Map the enum variant to the model name Ollama expects.
        let model = match self.model {
            OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
            OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
        };

        futures::future::try_join_all(texts.iter().map(|to_embed| {
            let request = OllamaEmbeddingRequest {
                model: model.to_string(),
                prompt: to_embed.text.to_string(),
            };
            // Serializing a struct of two plain strings cannot fail.
            let request = serde_json::to_string(&request).unwrap();

            async {
                let response = self
                    .client
                    .post_json("http://localhost:11434/api/embeddings", request.into())
                    .await?;

                let mut body = String::new();
                response.into_body().read_to_string(&mut body).await?;

                let response: OllamaEmbeddingResponse = serde_json::from_str(&body)
                    .context("Unable to parse response from Ollama embedding endpoint")?;

                Ok(Embedding::new(response.embedding))
            }
        }))
        .boxed()
    }

    fn batch_size(&self) -> usize {
        // TODO: Figure out a decent value
        10
    }
}
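
// A minimal usage sketch (not part of the original file). It assumes the
// caller already holds some concrete `HttpClient` and that `TextToEmbed::new`
// is the crate's constructor for inputs; adjust to the actual call site:
//
//     let provider = OllamaEmbeddingProvider::new(client, OllamaEmbeddingModel::NomicEmbedText);
//     let texts = [TextToEmbed::new("fn main() {}")];
//     let embeddings = provider.embed(&texts).await?;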