Reify Embedding/Sha1 structs that can be (de)serialized from SQL
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
parent
c763e728d1
commit
3001a46f69
5 changed files with 180 additions and 138 deletions
|
@ -1,7 +1,11 @@
|
|||
use crate::embedding::EmbeddingProvider;
|
||||
use anyhow::{anyhow, Ok, Result};
|
||||
use crate::embedding::{EmbeddingProvider, Embedding};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use sha1::{Digest, Sha1};
|
||||
use rusqlite::{
|
||||
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
|
||||
ToSql,
|
||||
};
|
||||
use sha1::Digest;
|
||||
use std::{
|
||||
cmp::{self, Reverse},
|
||||
collections::HashSet,
|
||||
|
@ -11,13 +15,43 @@ use std::{
|
|||
};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
/// A 20-byte SHA-1 digest.
///
/// Equality over the wrapped byte array is total, so `Eq` and `Hash` are
/// derived alongside `PartialEq`; this lets digests be used as keys in hashed
/// collections.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct Sha1([u8; 20]);
|
||||
|
||||
impl FromSql for Sha1 {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let blob = value.as_blob()?;
|
||||
let bytes =
|
||||
blob.try_into()
|
||||
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
|
||||
expected_size: 20,
|
||||
blob_size: blob.len(),
|
||||
})?;
|
||||
return Ok(Sha1(bytes));
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for Sha1 {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
self.0.to_sql()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&'_ str> for Sha1 {
|
||||
fn from(value: &'_ str) -> Self {
|
||||
let mut sha1 = sha1::Sha1::new();
|
||||
sha1.update(value);
|
||||
Self(sha1.finalize().into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Document {
|
||||
pub name: String,
|
||||
pub range: Range<usize>,
|
||||
pub content: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub sha1: [u8; 20],
|
||||
pub embedding: Option<Embedding>,
|
||||
pub sha1: Sha1,
|
||||
pub token_count: usize,
|
||||
}
|
||||
|
||||
|
@ -69,17 +103,16 @@ impl CodeContextRetriever {
|
|||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(&document_span);
|
||||
let sha1 = Sha1::from(document_span.as_str());
|
||||
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
|
||||
Ok(vec![Document {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: Vec::new(),
|
||||
embedding: Default::default(),
|
||||
name: language_name.to_string(),
|
||||
sha1: sha1.finalize().into(),
|
||||
sha1,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
@ -88,18 +121,14 @@ impl CodeContextRetriever {
|
|||
let document_span = MARKDOWN_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace("<item>", &content);
|
||||
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(&document_span);
|
||||
|
||||
let sha1 = Sha1::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
|
||||
Ok(vec![Document {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: Vec::new(),
|
||||
embedding: None,
|
||||
name: "Markdown".to_string(),
|
||||
sha1: sha1.finalize().into(),
|
||||
sha1,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
@ -279,15 +308,13 @@ impl CodeContextRetriever {
|
|||
);
|
||||
}
|
||||
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(&document_content);
|
||||
|
||||
let sha1 = Sha1::from(document_content.as_str());
|
||||
documents.push(Document {
|
||||
name,
|
||||
content: document_content,
|
||||
range: item_range.clone(),
|
||||
embedding: vec![],
|
||||
sha1: sha1.finalize().into(),
|
||||
embedding: None,
|
||||
sha1,
|
||||
token_count: 0,
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue