Multimodal embedding #1193

Open

@LoicDagnas

Description

Taking inspiration from the LLamaEmbedder and the multimodal support that has been added to LLamaInteractExecutor, I have been trying to implement a multimodal embedder. The main idea is to support Qwen2-VL based models specialized in screenshot embedding, such as:

IMO, it should work the way I did it, i.e.:

  • building the prompt manually
  • tokenizing the prompt before and after the image marker separately
  • feeding the model the tokens before the image, then the image via LLavaWeights.EvalImageEmbed, and finally the tokens after the image
  • getting the embedding of the last token <|endoftext|> and normalizing it

But:

  • I don't get the same vectors as the ones I obtain with Python code
  • even worse, embedding the same image twice doesn't give me the same vector
  • and even when I use two different context instances, I still don't get the same vector

Does it ring a bell for anyone?

Here is my class, along with a dummy unit test comparing two runs of the image embedding computation:

using System.Numerics.Tensors;
using System.Text;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Xunit;

namespace LLama.Unittest;

public sealed class LlamaMultimodalEmbedder : IDisposable
{
    private readonly LLavaWeights _llavaWeights;
    private readonly LLamaContext _context;

    public LlamaMultimodalEmbedder(LLamaContext context, LLavaWeights llavaWeights)
    {
        if (context.Params.UBatchSize != context.Params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size");

        _llavaWeights = llavaWeights;
        _context = context;

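        // Enable embedding output on the context so that llama_decode stores per-token
        // embeddings which GetEmbeddingsIth can read back after decoding.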
        NativeApi.llama_set_embeddings(_context.NativeHandle, true);
    }

    private bool _disposed;

    public void Dispose()
    {
        if (_disposed)
            return;

        _context.Dispose();
        _llavaWeights.Dispose();
        _disposed = true;
    }

    private const string ImageMarker = "<|image_pad|>";
    private readonly int _imageMarkerSize = ImageMarker.Length;

    private async Task<float[]> GetEmbedding(
        string? text,
        byte[]? image,
        CancellationToken cancellationToken = default)
    {
        // clear previous kv_cache values
        _context.NativeHandle.KvCacheClear();
        _context.NativeHandle.KvCacheRemove(LLamaSeqId.Zero, -1, -1);

        var hasText = !string.IsNullOrEmpty(text);
        var hasImage = image != null;

        if (!hasText && !hasImage)
            throw new ArgumentException("At least one of text or image must be provided");

        // Even if it implies a loss of genericity, we build the prompt manually for two reasons:
        // * history doesn't handle image content
        // * we aim to support Qwen2-VL-like models
        var promptBuilder = new StringBuilder();

        promptBuilder
            .Append("<|im_start|>system\n")
            .Append("You are a helpful assistant.<|im_end|>\n");

        promptBuilder.Append("<|im_start|>user\n");

        if (hasImage)
            promptBuilder.Append("<|vision_start|>").Append(ImageMarker).Append("<|vision_end|>");

        if (hasText)
            promptBuilder.Append(text);

        promptBuilder.Append("<|im_end|>\n");

        promptBuilder
            .Append("<|im_start|>assistant\n")
            .Append("<|endoftext|>");

        var prompt = promptBuilder.ToString();

        // Compute embeddings of the input image to be fed into the model
        using var imageEmbeddingHandle = hasImage ? GetImageEmbeddingHandle(image!) : null;

        var tokens = new List<LLamaToken>();
        var imageTokenIndex = -1;

        if (hasImage)
        {
            var imageIndexInPrompt = prompt.IndexOf(ImageMarker, StringComparison.Ordinal);

            // Tokenize text segment before <|image_pad|> tag
            var promptBeforeImage = prompt[..imageIndexInPrompt];
            var tokensBeforeImage = _context.Tokenize(promptBeforeImage, addBos: true, special: true);

            // Remember the position to add the image embeddings
            imageTokenIndex = tokensBeforeImage.Length;

            // Tokenize text segment after <|image_pad|> tag
            var promptAfterImage = prompt[(imageIndexInPrompt + _imageMarkerSize)..];
            var tokensAfterImage = _context.Tokenize(promptAfterImage, addBos: false, special: true);

            tokens.AddRange(tokensBeforeImage);
            tokens.AddRange(tokensAfterImage);
        }
        else
        {
            tokens.AddRange(_context.Tokenize(prompt, addBos: true, special: true));
        }

        var tokensCount = tokens.Count;

        if (tokensCount > _context.ContextSize)
            throw new ArgumentException(
                $"Embedding prompt is longer than the context window ({tokensCount} > {_context.ContextSize})");

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Evaluate prompt in batch-size chunks
        var batch = new LLamaBatch();
        var nPast = 0;

        var decodeResponse = await _context
            .DecodeAsync(tokens.GetRange(0, hasImage ? imageTokenIndex : tokensCount), LLamaSeqId.Zero, batch, nPast)
            .ConfigureAwait(false);

        nPast = decodeResponse.Item3;

        if (hasImage)
        {
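            // Inject the pre-computed image patch embeddings at the current position,
            // then decode the remaining text tokens on the same sequence.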
            _llavaWeights.EvalImageEmbed(_context, imageEmbeddingHandle!, ref nPast);

            decodeResponse = await _context
                .DecodeAsync(tokens.GetRange(imageTokenIndex, tokensCount - imageTokenIndex), LLamaSeqId.Zero, batch,
                    nPast)
                .ConfigureAwait(false);

            nPast = decodeResponse.Item3;
        }

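        // We read the embedding of a single token ourselves, so the context must not
        // apply any pooling (mean/cls) over the sequence.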
        var poolingType = _context.NativeHandle.PoolingType;

        if (poolingType != LLamaPoolingType.None)
            throw new NotSupportedException("Unsupported pooling type");

        var positions = batch.GetLogitPositions();

        if (positions == null)
            throw new InvalidOperationException("GetLogitPositions returned null");

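        // Take the embedding at the last position with logits enabled, i.e. the trailing
        // <|endoftext|> token, and L2-normalise it in place.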
        var embedding = _context.NativeHandle.GetEmbeddingsIth(positions[^1].Item2).ToArray();

        embedding.EuclideanNormalization();

        return embedding;
    }

    private SafeLlavaImageEmbedHandle GetImageEmbeddingHandle(byte[] imageBytes)
    {
        if (_llavaWeights == null)
            throw new InvalidOperationException("LLavaWeights is not loaded.");

        var embeddingsHandle = _llavaWeights.CreateImageEmbeddings(imageBytes);

        if (embeddingsHandle.IsInvalid)
            throw new InvalidOperationException(
                "Failed to create the image embedding handle, make sure the provided bytes are a valid image.");

        return embeddingsHandle;
    }

    public async Task<float[]> GetTextEmbedding(string text, CancellationToken cancellationToken) =>
        await GetEmbedding(text, null, cancellationToken).ConfigureAwait(false);

    public async Task<float[]> GetImageEmbedding(byte[] imageBytes, CancellationToken cancellationToken) =>
        await GetEmbedding(null, imageBytes, cancellationToken).ConfigureAwait(false);
}

public sealed class LLamaMultimodalEmbedderTests
{
    private const string ModelPath = @"path\to\model.gguf";
    private const string MmprojPath = @"path\to\mmproj.gguf";
    private const string ImagePath = @"path\to\image.png";
    
    [Fact]
    public async Task TestBasic()
    {
        var parameters = new ModelParams(ModelPath)
        {
            GpuLayerCount = 5
        };

        var model = await LLamaWeights.LoadFromFileAsync(parameters);
        var llavaWeights = await LLavaWeights.LoadFromFileAsync(MmprojPath);
        var context = model.CreateContext(parameters);

        var multimodalEmbedder = new LlamaMultimodalEmbedder(context, llavaWeights);

        var embedding1 = await multimodalEmbedder.GetImageEmbedding(
            await File.ReadAllBytesAsync(ImagePath),
            CancellationToken.None);
        
        var embedding2 = await multimodalEmbedder.GetImageEmbedding(
            await File.ReadAllBytesAsync(ImagePath),
            CancellationToken.None);
        
        var diff = TensorPrimitives.Norm(
            embedding1.Zip(embedding2, (a, b) => a - b).ToArray());
        
        Assert.True(diff < 10e-1);
    }
}
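
For reference, here is a minimal sketch of how I intend to use the embedder for text-to-screenshot retrieval once the discrepancy is resolved. This is only an illustration, assuming the LlamaMultimodalEmbedder class above builds against the current LLamaSharp API; the paths, the query string and GpuLayerCount are placeholders.

using System;
using System.IO;
using System.Numerics.Tensors;
using System.Threading;
using LLama;
using LLama.Common;
using LLama.Unittest;

var parameters = new ModelParams(@"path\to\model.gguf") { GpuLayerCount = 5 };
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var llavaWeights = await LLavaWeights.LoadFromFileAsync(@"path\to\mmproj.gguf");

// The embedder takes ownership of the context and the llava weights and disposes them itself.
using var embedder = new LlamaMultimodalEmbedder(model.CreateContext(parameters), llavaWeights);

var queryEmbedding = await embedder.GetTextEmbedding(
    "login form with a password field",
    CancellationToken.None);

var screenshotEmbedding = await embedder.GetImageEmbedding(
    await File.ReadAllBytesAsync(@"path\to\image.png"),
    CancellationToken.None);

// Both vectors are L2-normalised by GetEmbedding, so cosine similarity equals the dot product.
var similarity = TensorPrimitives.CosineSimilarity(queryEmbedding, screenshotEmbedding);
Console.WriteLine($"query/screenshot cosine similarity: {similarity}");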
