Description
Taking inspiration from the LlamaEmbedder and the multimodal support that has been added to LlamaInteractExecutor, I have been trying to implement a multimodal embedder. The main idea is to support Qwen2-VL-related models specialized in screenshot embedding, such as:
IMO, it should work the way I did it, i.e.:
- building the prompt manually
- tokenizing the prompt before and after the image marker separately
- feeding the model with the tokens before the image, then the image using LlavaWeights.EvalImageEmbed, and finally the tokens after the image
- getting the embedding of the last token (<|endoftext|>) and normalizing it
But:
- I don't get the same vectors as the ones I obtain with Python code
- even worse, embedding the same image twice doesn't give me the same vector
- and even when using two different context instances, I don't get the same vector
Does this ring a bell for anyone?
Here is my class, along with a dummy unit test comparing two runs of the image embedding computation:
using System.Numerics.Tensors;
using System.Text;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
namespace LLama.Unittest;
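// Experimental multimodal embedder: wraps an LLamaContext plus the LLava (mmproj) weights and
// returns a single normalized embedding for a text-only or image-only input.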
public sealed class LlamaMultimodalEmbedder : IDisposable
{
    private readonly LLavaWeights _llavaWeights;
    private readonly LLamaContext _context;

    public LlamaMultimodalEmbedder(LLamaContext context, LLavaWeights llavaWeights)
    {
        if (context.Params.UBatchSize != context.Params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size");

        _llavaWeights = llavaWeights;
        _context = context;

        NativeApi.llama_set_embeddings(_context.NativeHandle, true);
    }
    private bool _disposed;

    public void Dispose()
    {
        if (_disposed)
            return;

        _context.Dispose();
        _llavaWeights.Dispose();
        _disposed = true;
    }
    private const string ImageMarker = "<|image_pad|>";
    private readonly int _imageMarkerSize = ImageMarker.Length;
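    // Builds a Qwen2-VL style chat prompt around the optional image, evaluates it, and returns
    // the L2-normalized embedding of the last token.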
    private async Task<float[]> GetEmbedding(
        string? text,
        byte[]? image,
        CancellationToken cancellationToken = default)
    {
        // clear previous kv_cache values
        _context.NativeHandle.KvCacheClear();
        _context.NativeHandle.KvCacheRemove(LLamaSeqId.Zero, -1, -1);

        var hasText = !string.IsNullOrEmpty(text);
        var hasImage = image != null;
        if (!hasText && !hasImage)
            throw new ArgumentException("At least one of text or image must be provided");

        // Even if it implies a loss of genericity, we build the prompt manually for two reasons:
        // * history doesn't handle image content
        // * we aim to support Qwen2-VL-like models
        var promptBuilder = new StringBuilder();
        promptBuilder
            .Append("<|im_start|>system\n")
            .Append("You are a helpful assistant.<|im_end|>\n");
        promptBuilder.Append("<|im_start|>user\n");
        if (hasImage)
            promptBuilder.Append("<|vision_start|>").Append(ImageMarker).Append("<|vision_end|>");
        if (hasText)
            promptBuilder.Append(text);
        promptBuilder.Append("<|im_end|>\n");
        promptBuilder
            .Append("<|im_start|>assistant\n")
            .Append("<|endoftext|>");
        var prompt = promptBuilder.ToString();

        // Compute embeddings of the input image to be fed into the model
        using var imageEmbeddingHandle = hasImage ? GetImageEmbeddingHandle(image!) : null;

        var tokens = new List<LLamaToken>();
        var imageTokenIndex = -1;
        if (hasImage)
        {
            var imageIndexInPrompt = prompt.IndexOf(ImageMarker, StringComparison.Ordinal);

            // Tokenize text segment before <|image_pad|> tag
            var promptBeforeImage = prompt[..imageIndexInPrompt];
            var tokensBeforeImage = _context.Tokenize(promptBeforeImage, addBos: true, special: true);

            // Remember the position to add the image embeddings
            imageTokenIndex = tokensBeforeImage.Length;

            // Tokenize text segment after <|image_pad|> tag
            var promptAfterImage = prompt[(imageIndexInPrompt + _imageMarkerSize)..];
            var tokensAfterImage = _context.Tokenize(promptAfterImage, addBos: false, special: true);

            tokens.AddRange(tokensBeforeImage);
            tokens.AddRange(tokensAfterImage);
        }
        else
        {
            tokens.AddRange(_context.Tokenize(prompt, addBos: true, special: true));
        }

        var tokensCount = tokens.Count;
        if (tokensCount > _context.ContextSize)
            throw new ArgumentException(
                $"Embedding prompt is longer than the context window ({tokensCount} > {_context.ContextSize})");

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Evaluate prompt in batch-size chunks
        var batch = new LLamaBatch();
        var nPast = 0;
        var decodeResponse = await _context
            .DecodeAsync(tokens.GetRange(0, hasImage ? imageTokenIndex : tokensCount), LLamaSeqId.Zero, batch, nPast)
            .ConfigureAwait(false);
        nPast = decodeResponse.Item3;

        if (hasImage)
        {
            _llavaWeights.EvalImageEmbed(_context, imageEmbeddingHandle!, ref nPast);

            decodeResponse = await _context
                .DecodeAsync(tokens.GetRange(imageTokenIndex, tokensCount - imageTokenIndex), LLamaSeqId.Zero, batch,
                    nPast)
                .ConfigureAwait(false);
            nPast = decodeResponse.Item3;
        }
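        // Pooling is expected to be disabled: read the per-token embedding at the last position
        // that requested logits (the trailing <|endoftext|> token) and L2-normalize it.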
        var poolingType = _context.NativeHandle.PoolingType;
        if (poolingType != LLamaPoolingType.None)
            throw new NotSupportedException("Unsupported pooling type");

        var positions = batch.GetLogitPositions();
        if (positions == null)
            throw new InvalidOperationException("GetLogitPositions returned null");

        var embedding = _context.NativeHandle.GetEmbeddingsIth(positions[^1].Item2).ToArray();
        embedding.EuclideanNormalization();
        return embedding;
    }
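    // Runs the CLIP/mmproj encoder on the raw image bytes and returns a handle to the resulting image embeddings.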
    private SafeLlavaImageEmbedHandle GetImageEmbeddingHandle(byte[] imageBytes)
    {
        if (_llavaWeights == null)
            throw new InvalidOperationException("LLavaWeights is not loaded.");

        var embeddingsHandle = _llavaWeights.CreateImageEmbeddings(imageBytes);
        if (embeddingsHandle.IsInvalid)
            throw new InvalidOperationException(
                "Failed to create embedding handle, make sure that the image bytes are a valid image.");

        return embeddingsHandle;
    }
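    // Public entry points: text-only and image-only embedding.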
    public async Task<float[]> GetTextEmbedding(string text, CancellationToken cancellationToken) =>
        await GetEmbedding(text, null, cancellationToken).ConfigureAwait(false);

    public async Task<float[]> GetImageEmbedding(byte[] imageBytes, CancellationToken cancellationToken) =>
        await GetEmbedding(null, imageBytes, cancellationToken).ConfigureAwait(false);
}
public sealed class LLamaMultimodalEmbedderTests
{
    private const string ModelPath = @"path\to\model.gguf";
    private const string MmprojPath = @"path\to\mmproj.gguf";
    private const string ImagePath = @"path\to\image.png";
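    // Embeds the same image twice with a single embedder instance; the two vectors are expected to be (almost) identical.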
    [Fact]
    public async Task TestBasic()
    {
        var parameters = new ModelParams(ModelPath)
        {
            GpuLayerCount = 5
        };

        var model = await LLamaWeights.LoadFromFileAsync(parameters);
        var llavaWeights = await LLavaWeights.LoadFromFileAsync(MmprojPath);
        var context = model.CreateContext(parameters);
        var multimodalEmbedder = new LlamaMultimodalEmbedder(context, llavaWeights);

        var embedding1 = await multimodalEmbedder.GetImageEmbedding(
            await File.ReadAllBytesAsync(ImagePath),
            CancellationToken.None);
        var embedding2 = await multimodalEmbedder.GetImageEmbedding(
            await File.ReadAllBytesAsync(ImagePath),
            CancellationToken.None);

        var diff = TensorPrimitives.Norm(
            embedding1.Zip(embedding2, (a, b) => a - b).ToArray());
        Assert.True(diff < 10e-1);
    }
}
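As a side note, when comparing against the Python vectors, cosine similarity can be easier to read than the norm of the difference. Here is a minimal sketch of a helper that could sit inside the test class (AssertSameEmbedding and the tolerance are made up for illustration; it assumes both vectors are already L2-normalized, as they are after EuclideanNormalization, and only uses TensorPrimitives from System.Numerics.Tensors):
    // Hypothetical helper: compares two embeddings assumed to be L2-normalized already.
    // For normalized vectors, cosine similarity should be close to 1.0 when both implementations agree.
    private static void AssertSameEmbedding(float[] expected, float[] actual, float tolerance = 1e-3f)
    {
        Assert.Equal(expected.Length, actual.Length);

        var cosine = TensorPrimitives.CosineSimilarity(expected, actual);
        Assert.True(cosine > 1.0f - tolerance, $"Cosine similarity too low: {cosine}");
    }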