
feat(text_splitters): Add SemanticChunker splitter #7855

Open · wants to merge 9 commits into main
@@ -0,0 +1,306 @@
import { BaseDocumentTransformer, Document } from "@langchain/core/documents"
import { Embeddings } from "@langchain/core/embeddings"

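// Internal bookkeeping for a single sentence: the windowed ("combined") text,
// its embedding, and the cosine distance to the next window.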
type Sentence = {
sentence: string
combined_sentence_embedding?: number[]
combined_sentence?: string
distance_to_next?: number
}

type SentenceWithCombinedSentence = Sentence & { combined_sentence: string }
type SentenceWithEmbedding = Sentence & { combined_sentence_embedding: number[]; distance_to_next: number }

/**
 * Combine sentences based on buffer size.
 *
 * @param sentences - List of sentences to combine.
 * @param sentencesToCombine - Number of neighboring sentences to include before and after each sentence. Defaults to 1.
 * @returns List of sentences with combined sentences.
 */
function combineSentences(sentences: Sentence[], sentencesToCombine: number = 1) {

for (let i = 0; i < sentences.length; i++) {
let combinedSentence = ""

// Add sentences before the current one
for (let j = i - sentencesToCombine; j < i; j++) {
if (j >= 0) {
combinedSentence += sentences[j].sentence + " "
}
}

// Add the current sentence
combinedSentence += sentences[i].sentence

// Add sentences after the current one
for (let j = i + 1; j < i + 1 + sentencesToCombine; j++) {
if (j < sentences.length) {
combinedSentence += " " + sentences[j].sentence
}
}

// Store the combined sentence in the current sentence object
sentences[i].combined_sentence = combinedSentence
}

return sentences as SentenceWithCombinedSentence[]
}

/**
 * Calculate cosine distances between consecutive sentences.
 *
 * @param sentences - List of sentences to calculate distances for.
 * @returns Tuple of distances and sentences.
 */
function calculateCosineDistances(sentences: SentenceWithEmbedding[]): [number[], SentenceWithEmbedding[]] {

const distances: number[] = []

for (let i = 0; i < sentences.length - 1; i++) {
const embeddingCurrent = sentences[i].combined_sentence_embedding
const embeddingNext = sentences[i + 1].combined_sentence_embedding

// Calculate cosine similarity
const similarity = cosineSimilarity(embeddingCurrent, embeddingNext)

// Convert to cosine distance
const distance = 1 - similarity

// Append cosine distance to the list
distances.push(distance)

// Store distance on the current sentence object
sentences[i].distance_to_next = distance
}

return [distances, sentences]
}

// Utility function for cosine similarity
function cosineSimilarity(vecA: number[], vecB: number[]) {
const dotProduct = vecA.reduce((sum, a, i) => sum + a * vecB[i], 0)
const magnitudeA = Math.sqrt(vecA.reduce((sum, a) => sum + a * a, 0))
const magnitudeB = Math.sqrt(vecB.reduce((sum, b) => sum + b * b, 0))

return magnitudeA && magnitudeB ? dotProduct / (magnitudeA * magnitudeB) : 0
}

export enum BreakpointThresholdType {
PERCENTILE = "percentile",
STANDARD_DEVIATION = "standard_deviation",
INTERQUARTILE = "interquartile",
GRADIENT = "gradient",
}

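// Default threshold amounts per strategy: percentile values for "percentile"
// and "gradient", multipliers for "standard_deviation" and "interquartile".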
const BREAKPOINT_DEFAULTS: Record<BreakpointThresholdType, number> = {
percentile: 95,
standard_deviation: 3,
interquartile: 1.5,
gradient: 95,
}

export interface SemanticChunkerOptions {
sentencesToCombine?: number
sentenceSplitRegex?: RegExp
addStartIndex?: boolean
breakpointThresholdType?: BreakpointThresholdType
breakpointThresholdAmount?: number
numberOfChunks?: number
minChunkSize?: number
}

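/**
 * Splits text into semantically coherent chunks: each sentence is embedded
 * together with its neighbors, and a new chunk starts wherever the cosine
 * distance between consecutive windows exceeds a configurable threshold
 * (percentile, standard deviation, interquartile or gradient based, or
 * derived from a target `numberOfChunks`).
 */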
export class SemanticChunker extends BaseDocumentTransformer {
private sentencesToCombine: number = 1
private sentenceSplitRegex: RegExp = /(?<=[.?!])\s+/
private addStartIndex: boolean = false
private breakpointThresholdType: BreakpointThresholdType = BreakpointThresholdType.PERCENTILE
private breakpointThresholdAmount: number = BREAKPOINT_DEFAULTS[this.breakpointThresholdType]
private numberOfChunks?: number
private minChunkSize?: number

constructor(
private embeddings: Embeddings,
options?: SemanticChunkerOptions
) {
super()

if (options) {
if (options.sentencesToCombine !== undefined) this.sentencesToCombine = options.sentencesToCombine
if (options.sentenceSplitRegex !== undefined) this.sentenceSplitRegex = new RegExp(options.sentenceSplitRegex)
if (options.addStartIndex !== undefined) this.addStartIndex = options.addStartIndex
if (options.breakpointThresholdType !== undefined) {
this.breakpointThresholdType = options.breakpointThresholdType
this.breakpointThresholdAmount = BREAKPOINT_DEFAULTS[options.breakpointThresholdType]
}
if (options.breakpointThresholdAmount !== undefined) this.breakpointThresholdAmount = options.breakpointThresholdAmount
if (options.numberOfChunks !== undefined) this.numberOfChunks = options.numberOfChunks
if (options.minChunkSize !== undefined) this.minChunkSize = options.minChunkSize
}
}

// Utility functions
private mean(arr: number[]): number {
return arr.reduce((a, b) => a + b, 0) / arr.length
}

private std(arr: number[]): number {
const mean = this.mean(arr)
return Math.sqrt(arr.map((x) => (x - mean) ** 2).reduce((a, b) => a + b, 0) / arr.length)
}

private percentile(arr: number[], p: number): number {
const sorted = [...arr].sort((a, b) => a - b)
const index = Math.ceil((p / 100) * sorted.length) - 1
return sorted[Math.max(index, 0)]
}

private percentileValues(arr: number[], percentiles: number[]): number[] {
return percentiles.map((p) => this.percentile(arr, p))
}

private gradient(arr: number[]): number[] {
return arr.map((_, i, a) => (i === 0 ? a[i + 1] - a[i] : a[i] - a[i - 1])).slice(1)
}

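// Compute the distance threshold above which a chunk boundary is created,
// together with the array it applies to: the raw distances, or their
// gradient when the "gradient" strategy is used.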
private _calculateBreakpointThreshold(distances: number[]): [number, number[]] {
switch (this.breakpointThresholdType) {
case BreakpointThresholdType.PERCENTILE:
return [this.percentile(distances, this.breakpointThresholdAmount), distances]

case BreakpointThresholdType.STANDARD_DEVIATION:
return [this.mean(distances) + this.breakpointThresholdAmount * this.std(distances), distances]

case BreakpointThresholdType.INTERQUARTILE: {
const [q1, q3] = this.percentileValues(distances, [25, 75])
const iqr = q3 - q1
return [this.mean(distances) + this.breakpointThresholdAmount * iqr, distances]
}

case BreakpointThresholdType.GRADIENT: {
const distanceGradient = this.gradient(distances)
return [this.percentile(distanceGradient, this.breakpointThresholdAmount), distanceGradient]
}

default:
throw new Error(`Unexpected breakpointThresholdType: ${this.breakpointThresholdType}`)
}
}

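// When a target number of chunks is given, linearly map it onto a percentile
// (1 chunk maps to the 100th percentile, `distances.length` chunks to the 0th)
// and use that percentile of the distances as the threshold.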
private _thresholdFromClusters(distances: number[]): number {
if (this.numberOfChunks === undefined) {
throw new Error("This should never be called if `numberOfChunks` is undefined.")
}

const x1 = distances.length
const y1 = 0.0
const x2 = 1.0
const y2 = 100.0

const x = Math.max(Math.min(this.numberOfChunks, x1), x2)
const y = x2 === x1 ? y2 : y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)

return this.percentile(distances, Math.min(Math.max(y, 0), 100))
}

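// Build the combined sentence windows, embed them, and return the cosine
// distances between consecutive windows along with the annotated sentences.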
private async _calculateSentenceDistances(singleSentencesList: string[]): Promise<[number[], SentenceWithEmbedding[]]> {
const sentences = combineSentences(
singleSentencesList.map((sentence) => ({ sentence })),
this.sentencesToCombine
)

const embeddings = await this.embeddings.embedDocuments(sentences.map((x) => x.combined_sentence))

sentences.forEach((sentence, i) => {
sentence.combined_sentence_embedding = embeddings[i]
})

return calculateCosineDistances(sentences as SentenceWithEmbedding[])
}

private _getSingleSentencesList(text: string): string[] {
return text.split(this.sentenceSplitRegex)
}

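/**
 * Split text into semantically coherent chunks: split into sentences,
 * measure cosine distances between neighboring sentence windows, derive a
 * breakpoint threshold, and join the sentences between breakpoints
 * (respecting `minChunkSize` when set).
 */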
async splitText(text: string): Promise<string[]> {
const singleSentencesList = this._getSingleSentencesList(text)

if (singleSentencesList.length <= 1) return singleSentencesList

if (this.breakpointThresholdType === BreakpointThresholdType.GRADIENT && singleSentencesList.length === 2) {
return singleSentencesList
}

const [distances, sentences] = await this._calculateSentenceDistances(singleSentencesList)
let breakpointDistanceThreshold: number
let breakpointArray: number[]

if (this.numberOfChunks !== undefined) {
breakpointDistanceThreshold = this._thresholdFromClusters(distances)
breakpointArray = distances
} else {
;[breakpointDistanceThreshold, breakpointArray] = this._calculateBreakpointThreshold(distances)
}

const indicesAboveThresh = breakpointArray.map((x, i) => (x > breakpointDistanceThreshold ? i : -1)).filter((i) => i !== -1)

const chunks: string[] = []
let startIndex = 0

for (const index of indicesAboveThresh) {
const endIndex = index
const group = sentences.slice(startIndex, endIndex + 1)
const combinedText = group.map((d) => d.sentence).join(" ")

if (this.minChunkSize !== undefined && combinedText.length < this.minChunkSize) continue
chunks.push(combinedText)
startIndex = index + 1
}

if (startIndex < sentences.length) {
chunks.push(
sentences
.slice(startIndex)
.map((d) => d.sentence)
.join(" ")
)
}

return chunks
}

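/**
 * Split each input text and wrap the resulting chunks in Documents carrying
 * the matching metadata; when `addStartIndex` is set, a `start_index` field
 * records each chunk's running character offset.
 */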
async createDocuments(texts: string[], metadatas?: Record<string, any>[]): Promise<Document[]> {
const _metadatas = metadatas || Array(texts.length).fill({})
const documents: Document[] = []

for (const [i, text] of texts.entries()) {
let startIndex = 0
const chunks = await this.splitText(text)

chunks.forEach((chunk) => {
const metadata = { ..._metadatas[i] }

if (this.addStartIndex) metadata.start_index = startIndex

documents.push(new Document({ pageContent: chunk, metadata }))
startIndex += chunk.length
})
}

return documents
}

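/**
 * Split a list of Documents, reusing each source document's metadata for its chunks.
 */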
async splitDocuments(documents: Document[]): Promise<Document[]> {
const texts: string[] = []
const metadatas: Record<string, any>[] = []

for (const doc of documents) {
texts.push(doc.pageContent)
metadatas.push(doc.metadata)
}

return this.createDocuments(texts, metadatas)
}

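// BaseDocumentTransformer entry point; delegates to splitDocuments.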
async transformDocuments(documents: Document[]): Promise<Document[]> {
return this.splitDocuments(documents)
}
}
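
A minimal usage sketch for reviewers. The embeddings class, the local import path "./semantic_chunker", and the sample text are illustrative placeholders, not part of this diff; any `Embeddings` implementation can be passed.

import { OpenAIEmbeddings } from "@langchain/openai"
import { SemanticChunker } from "./semantic_chunker"

const run = async () => {
// minChunkSize is optional; small groups below it are merged into the next chunk
const splitter = new SemanticChunker(new OpenAIEmbeddings(), { minChunkSize: 100 })
const docs = await splitter.createDocuments([
"The weather has been mild all week. Tomorrow looks sunny as well. Transformers rely on attention instead of recurrence.",
])
console.log(docs.map((d) => d.pageContent))
}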