From 7db35ac86043efb21920db4cdf3acf132c04cce3 Mon Sep 17 00:00:00 2001 From: "caoduanxin.cdx" <13622150883@163.com> Date: Tue, 6 Dec 2022 16:09:43 +0800 Subject: [PATCH] TairVector: support cluster --- pom.xml | 2 +- .../aliyun/tair/tairvector/TairVector.java | 317 +++++-------- .../tair/tairvector/TairVectorCluster.java | 65 +-- .../tair/tairvector/TairVectorPipeline.java | 10 +- .../tair/tairvector/TairVectorShard.java | 350 ++++++++++++++ .../aliyun/tair/tairvector/VectorShard.java | 55 +++ .../factory/VectorBuilderFactory.java | 13 +- .../tair/tairvector/params/HscanParams.java | 12 +- .../tair/tests/example/VectorSearch.java | 2 +- .../tairvector/TairVectorClusterTest.java | 350 +++++++------- .../tairvector/TairVectorPipelineTest.java | 14 +- .../tests/tairvector/TairVectorShardTest.java | 439 ++++++++++++++++++ .../tair/tests/tairvector/TairVectorTest.java | 11 +- .../tests/tairvector/TairVectorTestBase.java | 30 +- 14 files changed, 1193 insertions(+), 477 deletions(-) create mode 100644 src/main/java/com/aliyun/tair/tairvector/TairVectorShard.java create mode 100644 src/main/java/com/aliyun/tair/tairvector/VectorShard.java create mode 100644 src/test/java/com/aliyun/tair/tests/tairvector/TairVectorShardTest.java diff --git a/pom.xml b/pom.xml index ffbcc0e..5f40d7e 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.aliyun.tair alibabacloud-tairjedis-sdk - 2.4.0-SNAPSHOT + 3.0.4 jar alibabacloud-tairjedis-sdk diff --git a/src/main/java/com/aliyun/tair/tairvector/TairVector.java b/src/main/java/com/aliyun/tair/tairvector/TairVector.java index 55e66fa..b825e60 100644 --- a/src/main/java/com/aliyun/tair/tairvector/TairVector.java +++ b/src/main/java/com/aliyun/tair/tairvector/TairVector.java @@ -1,12 +1,5 @@ package com.aliyun.tair.tairvector; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import com.aliyun.tair.ModuleCommand; import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; import com.aliyun.tair.tairvector.params.DistanceMethod; @@ -15,35 +8,28 @@ import com.aliyun.tair.util.JoinParameters; import redis.clients.jedis.BuilderFactory; import redis.clients.jedis.Jedis; -import redis.clients.jedis.JedisPool; import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; +import java.util.*; +import java.util.stream.Collectors; + import static redis.clients.jedis.Protocol.toByteArray; -public class TairVector { +public class TairVector implements VectorShard { private Jedis jedis; - private JedisPool jedisPool; public TairVector(Jedis jedis) { this.jedis = jedis; } - public TairVector(JedisPool jedisPool) { - this.jedisPool = jedisPool; - } - private Jedis getJedis() { - if (jedisPool != null) { - return jedisPool.getResource(); - } return jedis; } - private void releaseJedis(Jedis jedis) { - if (jedisPool != null) { - jedis.close(); - } + @Override + public void quit() { + jedis.quit(); } /** @@ -61,24 +47,16 @@ private void releaseJedis(Jedis jedis) { * M default 16 * @return Success: +OK; Fail: error */ + @Override public String tvscreateindex(final String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... params) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(params))); - return BuilderFactory.STRING.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(params))); + return BuilderFactory.STRING.build(obj); } + @Override public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final byte[]... params) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params)); - return BuilderFactory.BYTE_ARRAY.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params)); + return BuilderFactory.BYTE_ARRAY.build(obj); } /** @@ -89,24 +67,16 @@ public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, D * @param index index name * @return Success: string_map, Fail: empty */ + @Override public Map tvsgetindex(final String index) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index)); - return BuilderFactory.STRING_MAP.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index)); + return BuilderFactory.STRING_MAP.build(obj); } + @Override public Map tvsgetindex(byte[] index) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSGETINDEX, index); - return BuilderFactory.BYTE_ARRAY_MAP.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSGETINDEX, index); + return BuilderFactory.BYTE_ARRAY_MAP.build(obj); } /** @@ -117,24 +87,16 @@ public Map tvsgetindex(byte[] index) { * @param index index name * @return Success: 1; Fail: 0 */ + @Override public Long tvsdelindex(final String index) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index)); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index)); + return BuilderFactory.LONG.build(obj); } + @Override public Long tvsdelindex(byte[] index) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSDELINDEX, index); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSDELINDEX, index); + return BuilderFactory.LONG.build(obj); } @@ -151,16 +113,11 @@ public Long tvsdelindex(byte[] index) { * @return A ScanResult. {@link VectorBuilderFactory#SCAN_CURSOR_STRING} */ public ScanResult tvsscanindex(Long cursor, HscanParams params) { - Jedis jedis = getJedis(); - try { - final List args = new ArrayList(); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = jedis.sendCommand(ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); - } finally { - releaseJedis(jedis); - } + final List args = new ArrayList(); + args.add(toByteArray(cursor)); + args.addAll(params.getParams()); + Object obj = getJedis().sendCommand(ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); } @@ -177,24 +134,16 @@ public ScanResult tvsscanindex(Long cursor, HscanParams params) { * {@literal k} if success, k is the number of fields that were added.. * throw error like "(error) Illegal vector dimensions" if error */ + @Override public Long tvshset(final String index, final String entityid, final String vector, final String... params) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params))); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params))); + return BuilderFactory.LONG.build(obj); } + @Override public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]... params) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params)); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params)); + return BuilderFactory.LONG.build(obj); } /** @@ -206,24 +155,16 @@ public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[].. * @param entityid entity id * @return Map, an empty list when {@code entityid} does not exist. */ + @Override public Map tvshgetall(final String index, final String entityid) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); - return BuilderFactory.STRING_MAP.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); + return BuilderFactory.STRING_MAP.build(obj); } + @Override public Map tvshgetall(byte[] index, byte[] entityid) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHGETALL, index, entityid); - return BuilderFactory.BYTE_ARRAY_MAP.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSHGETALL, index, entityid); + return BuilderFactory.BYTE_ARRAY_MAP.build(obj); } /** @@ -236,24 +177,16 @@ public Map tvshgetall(byte[] index, byte[] entityid) { * @param attrs attrs * @return List, an empty list when {@code entityid} or {@code attrs} does not exist . */ + @Override public List tvshmget(final String index, final String entityid, final String... attrs) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); - return BuilderFactory.STRING_LIST.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); + return BuilderFactory.STRING_LIST.build(obj); } + @Override public List tvshmget(byte[] index, byte[] entityid, byte[]... attrs) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs)); - return BuilderFactory.BYTE_ARRAY_LIST.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs)); + return BuilderFactory.BYTE_ARRAY_LIST.build(obj); } @@ -267,24 +200,16 @@ public List tvshmget(byte[] index, byte[] entityid, byte[]... attrs) { * @return Long integer-reply the number of fields that were removed from the tair-vector * not including specified but no existing fields. */ + @Override public Long tvsdel(final String index, final String entityid) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); + return BuilderFactory.LONG.build(obj); } + @Override public Long tvsdel(byte[] index, byte[] entityid) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSDEL, index, entityid); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSDEL, index, entityid); + return BuilderFactory.LONG.build(obj); } /** @@ -294,29 +219,20 @@ public Long tvsdel(byte[] index, byte[] entityid) { * * @param index index name * @param entityid entity id - * @param attr attr * @param attrs other attrs * @return Long integer-reply the number of fields that were removed from the tair-vector * not including specified but no existing fields. */ - public Long tvshdel(final String index, final String entityid, final String attr, final String... attrs) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(attr), SafeEncoder.encodeMany(attrs))); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + @Override + public Long tvshdel(final String index, final String entityid, final String... attrs) { + Object obj = getJedis().sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); + return BuilderFactory.LONG.build(obj); } - public Long tvshdel(byte[] index, byte[] entityid, byte[] attr, byte[]... attrs) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attr, attrs)); - return BuilderFactory.LONG.build(obj); - } finally { - releaseJedis(jedis); - } + @Override + public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) { + Object obj = getJedis().sendCommand(ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attrs)); + return BuilderFactory.LONG.build(obj); } @@ -333,32 +249,24 @@ public Long tvshdel(byte[] index, byte[] entityid, byte[] attr, byte[]... attrs) * `NOVAL` - The return result contains no data portion, only cursor information * @return A ScanResult. */ + @Override public ScanResult tvsscan(final String index, Long cursor, HscanParams params) { - Jedis jedis = getJedis(); - try { - final List args = new ArrayList(); - args.add(SafeEncoder.encode(index)); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = jedis.sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); - } finally { - releaseJedis(jedis); - } + final List args = new ArrayList(); + args.add(SafeEncoder.encode(index)); + args.add(toByteArray(cursor)); + args.addAll(params.getParams()); + Object obj = getJedis().sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); } + @Override public ScanResult tvsscan(byte[] index, Long cursor, HscanParams params) { - Jedis jedis = getJedis(); - try { - final List args = new ArrayList(); - args.add(index); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = jedis.sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_BYTE.build(obj); - } finally { - releaseJedis(jedis); - } + final List args = new ArrayList(); + args.add(index); + args.add(toByteArray(cursor)); + args.addAll(params.getParams()); + Object obj = getJedis().sendCommand(ModuleCommand.TVSSCAN, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.SCAN_CURSOR_BYTE.build(obj); } /** @@ -373,10 +281,12 @@ public ScanResult tvsscan(byte[] index, Long cursor, HscanParams params) * ef_search range [0, 1000] * @return VectorBuilderFactory.Knn<> */ + @Override public VectorBuilderFactory.Knn tvsknnsearch(final String index, Long topn, final String vector, final String... params) { return tvsknnsearchfilter(index, topn, vector, "", params); } + @Override public VectorBuilderFactory.Knn tvsknnsearch(byte[] index, Long topn, byte[] vector, final byte[]... params) { return tvsknnsearchfilter(index, topn, vector, SafeEncoder.encode(""), params); } @@ -394,25 +304,17 @@ public VectorBuilderFactory.Knn tvsknnsearch(byte[] index, Long topn, by * ef_search range [0, 1000] * @return VectorBuilderFactory.Knn<> */ + @Override public VectorBuilderFactory.Knn tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), + Object obj = getJedis().sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(topn), SafeEncoder.encode(vector), SafeEncoder.encode(pattern), SafeEncoder.encodeMany(params))); - return VectorBuilderFactory.STRING_KNN_RESULT.build(obj); - } finally { - releaseJedis(jedis); - } + return VectorBuilderFactory.STRING_KNN_RESULT.build(obj); } + @Override public VectorBuilderFactory.Knn tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params) { - Jedis jedis = getJedis(); - try { - Object obj = jedis.sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params)); - return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj); - } finally { - releaseJedis(jedis); - } + Object obj = getJedis().sendCommand(ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params)); + return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj); } @@ -426,10 +328,12 @@ public VectorBuilderFactory.Knn tvsknnsearchfilter(byte[] index, Long to * ef_search range [0, 1000] * @return Collection<> */ + @Override public Collection> tvsmknnsearch(final String index, Long topn, Collection vectors, final String... params) { return tvsmknnsearchfilter(index, topn, vectors, "", params); } + @Override public Collection> tvsmknnsearch(byte[] index, Long topn, Collection vectors, final byte[]... params) { return tvsmknnsearchfilter(index, topn, vectors, SafeEncoder.encode(""), params); } @@ -445,37 +349,30 @@ public Collection> tvsmknnsearch(byte[] index, * ef_search range [0, 1000] * @return Collection<> */ + @Override public Collection> tvsmknnsearchfilter(final String index, Long topn, Collection vectors, final String pattern, final String... params) { - Jedis jedis = getJedis(); - try { - final List args = new ArrayList(); - args.add(SafeEncoder.encode(index)); - args.add(toByteArray(topn)); - args.add(toByteArray(vectors.size())); - args.addAll(vectors.stream().map(vector -> SafeEncoder.encode(vector)).collect(Collectors.toList())); - args.add(SafeEncoder.encode(pattern)); - args.addAll(Arrays.stream(params).map(str -> SafeEncoder.encode(str)).collect(Collectors.toList())); - Object obj = jedis.sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj); - } finally { - releaseJedis(jedis); - } + final List args = new ArrayList(); + args.add(SafeEncoder.encode(index)); + args.add(toByteArray(topn)); + args.add(toByteArray(vectors.size())); + args.addAll(vectors.stream().map(vector -> SafeEncoder.encode(vector)).collect(Collectors.toList())); + args.add(SafeEncoder.encode(pattern)); + args.addAll(Arrays.stream(params).map(str -> SafeEncoder.encode(str)).collect(Collectors.toList())); + Object obj = getJedis().sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj); } + @Override public Collection> tvsmknnsearchfilter(byte[] index, Long topn, Collection vectors, byte[] pattern, final byte[]... params) { - Jedis jedis = getJedis(); - try { - final List args = new ArrayList(); - args.add(index); - args.add(toByteArray(topn)); - args.add(toByteArray(vectors.size())); - args.addAll(vectors); - args.add(pattern); - args.addAll(Arrays.stream(params).collect(Collectors.toList())); - Object obj = jedis.sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj); - } finally { - releaseJedis(jedis); - } + final List args = new ArrayList(); + args.add(index); + args.add(toByteArray(topn)); + args.add(toByteArray(vectors.size())); + args.addAll(vectors); + args.add(pattern); + args.addAll(Arrays.stream(params).collect(Collectors.toList())); + Object obj = getJedis().sendCommand(ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); + return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj); } + } diff --git a/src/main/java/com/aliyun/tair/tairvector/TairVectorCluster.java b/src/main/java/com/aliyun/tair/tairvector/TairVectorCluster.java index a59f1ba..f524f1d 100644 --- a/src/main/java/com/aliyun/tair/tairvector/TairVectorCluster.java +++ b/src/main/java/com/aliyun/tair/tairvector/TairVectorCluster.java @@ -1,14 +1,6 @@ package com.aliyun.tair.tairvector; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import com.aliyun.tair.ModuleCommand; -import com.aliyun.tair.tairhash.factory.HashBuilderFactory; import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; import com.aliyun.tair.tairvector.params.DistanceMethod; import com.aliyun.tair.tairvector.params.HscanParams; @@ -19,15 +11,19 @@ import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; +import java.util.*; +import java.util.stream.Collectors; + import static redis.clients.jedis.Protocol.toByteArray; -public class TairVectorCluster { +public class TairVectorCluster implements VectorShard { private JedisCluster jc; public TairVectorCluster(JedisCluster jc) { this.jc = jc; } + @Override public void quit() { if (jc != null) { jc.close(); @@ -46,11 +42,13 @@ public void quit() { * @param attrs other columns, optional * @return Success: +OK; Fail: error */ + @Override public String tvscreateindex(final String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attrs) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(attrs))); return BuilderFactory.STRING.build(obj); } + @Override public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final byte[]... params) { Object obj = jc.sendCommand(index, ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params)); return BuilderFactory.BYTE_ARRAY.build(obj); @@ -64,11 +62,13 @@ public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, D * @param index index name * @return Success: string_map, Fail: empty */ + @Override public Map tvsgetindex(final String index) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index)); return BuilderFactory.STRING_MAP.build(obj); } + @Override public Map tvsgetindex(byte[] index) { Object obj = jc.sendCommand(index, ModuleCommand.TVSGETINDEX, index); return BuilderFactory.BYTE_ARRAY_MAP.build(obj); @@ -82,37 +82,18 @@ public Map tvsgetindex(byte[] index) { * @param index index name * @return Success: 1; Fail: 0 */ + @Override public Long tvsdelindex(final String index) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index)); return BuilderFactory.LONG.build(obj); } + @Override public Long tvsdelindex(byte[] index) { Object obj = jc.sendCommand(index, ModuleCommand.TVSDELINDEX, index); return BuilderFactory.LONG.build(obj); } - - /** - * TVS.SCANINDEX TVS.SCANINDEX index_name - *

- * scan index - * - * @param cursor start offset - * @param params the params: [MATCH pattern] [COUNT count] - * `MATCH` - Set the pattern which is used to filter the results - * `COUNT` - Set the number of fields in a single scan (default is 10) - * `NOVAL` - The return result contains no data portion, only cursor information - * @return A ScanResult. {@link HashBuilderFactory#EXHSCAN_RESULT_STRING} - */ - public ScanResult tvsscanindex(Long cursor, HscanParams params) { - final List args = new ArrayList(); - args.add(toByteArray(cursor)); - args.addAll(params.getParams()); - Object obj = jc.sendCommand(toByteArray(cursor), ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][])); - return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); - } - /** * TVS.HSET TVS.HSET index entityid vector [(attribute_key attribute_value) ...] *

@@ -126,11 +107,13 @@ public ScanResult tvsscanindex(Long cursor, HscanParams params) { * {@literal k} if success, k is the number of fields that were added.. * throw error like "(error) Illegal vector dimensions" if error */ + @Override public Long tvshset(final String index, final String entityid, final String vector, final String... params) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params))); return BuilderFactory.LONG.build(obj); } + @Override public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]... params) { Object obj = jc.sendCommand(index, ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params)); return BuilderFactory.LONG.build(obj); @@ -145,11 +128,13 @@ public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[].. * @param entityid entity id * @return Map, an empty list when {@code entityid} does not exist. */ + @Override public Map tvshgetall(final String index, final String entityid) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); return BuilderFactory.STRING_MAP.build(obj); } + @Override public Map tvshgetall(byte[] index, byte[] entityid) { Object obj = jc.sendCommand(index, ModuleCommand.TVSHGETALL, index, entityid); return BuilderFactory.BYTE_ARRAY_MAP.build(obj); @@ -165,11 +150,13 @@ public Map tvshgetall(byte[] index, byte[] entityid) { * @param attrs attrs * @return List, an empty list when {@code entityid} or {@code attrs} does not exist . */ + @Override public List tvshmget(final String index, final String entityid, final String... attrs) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); return BuilderFactory.STRING_LIST.build(obj); } + @Override public List tvshmget(byte[] index, byte[] entityid, byte[]... attrs) { Object obj = jc.sendCommand(index, ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs)); return BuilderFactory.BYTE_ARRAY_LIST.build(obj); @@ -186,11 +173,13 @@ public List tvshmget(byte[] index, byte[] entityid, byte[]... attrs) { * @return Long integer-reply the number of fields that were removed from the tair-vector * not including specified but non existing fields. */ + @Override public Long tvsdel(final String index, final String entityid) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid)); return BuilderFactory.LONG.build(obj); } + @Override public Long tvsdel(byte[] index, byte[] entityid) { Object obj = jc.sendCommand(index, ModuleCommand.TVSDEL, index, entityid); return BuilderFactory.LONG.build(obj); @@ -207,13 +196,15 @@ public Long tvsdel(byte[] index, byte[] entityid) { * @return Long integer-reply the number of fields that were removed from the tair-vector * not including specified but non existing fields. */ + @Override public Long tvshdel(final String index, final String entityid, final String... attrs) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs))); return BuilderFactory.LONG.build(obj); } + @Override public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) { - Object obj = jc.sendCommand(index, ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attrs)); + Object obj = jc.sendCommand(index, ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid,attrs)); return BuilderFactory.LONG.build(obj); } @@ -231,6 +222,7 @@ public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) { * `NOVAL` - The return result contains no data portion, only cursor information * @return A ScanResult. */ + @Override public ScanResult tvsscan(final String index, Long cursor, HscanParams params) { final List args = new ArrayList(); args.add(SafeEncoder.encode(index)); @@ -240,6 +232,7 @@ public ScanResult tvsscan(final String index, Long cursor, HscanParams p return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj); } + @Override public ScanResult tvsscan(byte[] index, Long cursor, HscanParams params) { final List args = new ArrayList(); args.add(index); @@ -261,10 +254,12 @@ public ScanResult tvsscan(byte[] index, Long cursor, HscanParams params) * ef_search range [0, 1000] * @return VectorBuilderFactory.Knn<> */ + @Override public VectorBuilderFactory.Knn tvsknnsearch(final String index, Long topn, final String vector, final String... params) { return tvsknnsearchfilter(index, topn, vector, "", params); } + @Override public VectorBuilderFactory.Knn tvsknnsearch(byte[] index, Long topn, byte[] vector, final byte[]... params) { return tvsknnsearchfilter(index, topn, vector, SafeEncoder.encode(""), params); } @@ -282,12 +277,14 @@ public VectorBuilderFactory.Knn tvsknnsearch(byte[] index, Long topn, by * ef_search range [0, 1000] * @return VectorBuilderFactory.Knn<> */ + @Override public VectorBuilderFactory.Knn tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params) { Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(topn), SafeEncoder.encode(vector), SafeEncoder.encode(pattern), SafeEncoder.encodeMany(params))); return VectorBuilderFactory.STRING_KNN_RESULT.build(obj); } + @Override public VectorBuilderFactory.Knn tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params) { Object obj = jc.sendCommand(index, ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params)); return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj); @@ -303,10 +300,12 @@ public VectorBuilderFactory.Knn tvsknnsearchfilter(byte[] index, Long to * ef_search range [0, 1000] * @return Collection<> */ + @Override public Collection> tvsmknnsearch(final String index, Long topn, Collection vectors, final String... params) { return tvsmknnsearchfilter(index, topn, vectors, "", params); } + @Override public Collection> tvsmknnsearch(byte[] index, Long topn, Collection vectors, final byte[]... params) { return tvsmknnsearchfilter(index, topn, vectors, SafeEncoder.encode(""), params); } @@ -322,6 +321,7 @@ public Collection> tvsmknnsearch(byte[] index, * ef_search range [0, 1000] * @return Collection<> */ + @Override public Collection> tvsmknnsearchfilter(final String index, Long topn, Collection vectors, final String pattern, final String... params) { final List args = new ArrayList(); args.add(SafeEncoder.encode(index)); @@ -334,6 +334,7 @@ public Collection> tvsmknnsearchfilter(final St return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj); } + @Override public Collection> tvsmknnsearchfilter(byte[] index, Long topn, Collection vectors, byte[] pattern, final byte[]... params) { final List args = new ArrayList(); args.add(index); @@ -345,4 +346,6 @@ public Collection> tvsmknnsearchfilter(byte[] i Object obj = jc.sendCommand(index, ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][])); return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj); } + + } diff --git a/src/main/java/com/aliyun/tair/tairvector/TairVectorPipeline.java b/src/main/java/com/aliyun/tair/tairvector/TairVectorPipeline.java index 2103f81..1aae0c2 100644 --- a/src/main/java/com/aliyun/tair/tairvector/TairVectorPipeline.java +++ b/src/main/java/com/aliyun/tair/tairvector/TairVectorPipeline.java @@ -1,12 +1,5 @@ package com.aliyun.tair.tairvector; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import com.aliyun.tair.ModuleCommand; import com.aliyun.tair.tairhash.factory.HashBuilderFactory; import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; @@ -20,6 +13,9 @@ import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; +import java.util.*; +import java.util.stream.Collectors; + import static redis.clients.jedis.Protocol.toByteArray; public class TairVectorPipeline extends Pipeline { diff --git a/src/main/java/com/aliyun/tair/tairvector/TairVectorShard.java b/src/main/java/com/aliyun/tair/tairvector/TairVectorShard.java new file mode 100644 index 0000000..2fdba65 --- /dev/null +++ b/src/main/java/com/aliyun/tair/tairvector/TairVectorShard.java @@ -0,0 +1,350 @@ +package com.aliyun.tair.tairvector; + +import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; +import com.aliyun.tair.tairvector.factory.VectorBuilderFactory.Knn; +import com.aliyun.tair.tairvector.params.DistanceMethod; +import com.aliyun.tair.tairvector.params.HscanParams; +import com.aliyun.tair.tairvector.params.IndexAlgorithm; +import redis.clients.jedis.ScanResult; +import redis.clients.jedis.util.JedisClusterCRC16; +import redis.clients.jedis.util.SafeEncoder; + +import java.util.*; + +public class TairVectorShard { + private VectorShard vectirInstance; + private int shardCount; + + public TairVectorShard(VectorShard vectirInstance, int shardCount) { + this.vectirInstance = vectirInstance; + if (shardCount < 1) { + throw new IllegalArgumentException("shards should not be less than 1"); + } + this.shardCount = shardCount; + } + + public void quit() { + this.vectirInstance.quit(); + } + + public String tvscreateindex(String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, String... params) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + String result = null; + for (String indexName : indexNames) { + result = this.vectirInstance.tvscreateindex(indexName, dims, algorithm, method, params); + if (!result.equals("OK")) { + //TODO delete pre index + return result; + } + } + return result; + } + + public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, byte[]... params) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + byte[] result = null; + for (String indexName : indexNames) { + result = this.vectirInstance.tvscreateindex(SafeEncoder.encode(indexName), dims, algorithm, method, params); + if (!SafeEncoder.encode(result).equals("OK")) { + //TODO delete pre index + return result; + } + } + return result; + } + + public List> tvsgetindex(String index) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + List> results = new ArrayList<>(); + for (String indexName : indexNames) { + Map result = this.vectirInstance.tvsgetindex(indexName); + if (result == null || result.isEmpty()) + continue; + else + results.add(result); + } + return results; + } + + public List> tvsgetindex(byte[] index) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + List> results = new ArrayList<>(); + for (String indexName : indexNames) { + Map result = this.vectirInstance.tvsgetindex(SafeEncoder.encode(indexName)); + if (result == null || result.isEmpty()) + continue; + else + results.add(result); + } + return results; + } + + public Long tvsdelindex(String index) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + Long result = new Long(0); + for (String indexName : indexNames) { + result += this.vectirInstance.tvsdelindex(indexName); + } + return result; + } + + public Long tvsdelindex(byte[] index) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + Long result = new Long(0); + for (String indexName : indexNames) { + result += this.vectirInstance.tvsdelindex(SafeEncoder.encode(indexName)); + } + return result; + } + + public Long tvshset(String index, String key, String vector, String... params) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshset(indexName, key, vector, params); + } + + public Long tvshset(byte[] index, byte[] key, byte[] vector, byte[]... params) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshset(SafeEncoder.encode(indexName), key, vector, params); + } + + public Map tvshgetall(String index, String key) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshgetall(indexName, key); + } + + public Map tvshgetall(byte[] index, byte[] key) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshgetall(SafeEncoder.encode(indexName), key); + } + + public List tvshmget(String index, String key, String... attrs) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshmget(indexName, key, attrs); + } + + public List tvshmget(byte[] index, byte[] key, byte[]... attrs) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshmget(SafeEncoder.encode(indexName), key, attrs); + } + + public Long tvsdel(String index, String key) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvsdel(indexName, key); + } + + public Long tvsdel(byte[] index, byte[] key) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvsdel(SafeEncoder.encode(indexName), key); + } + + + public Long tvshdel(String index, String key, String... attrs) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshdel(indexName, key, attrs); + } + + public Long tvshdel(byte[] index, byte[] key, byte[]... attrs) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + int slotId = JedisClusterCRC16.getSlot(key); + String indexName = indexNames.get(slotId % indexNames.size()); + return this.vectirInstance.tvshdel(SafeEncoder.encode(indexName), key, attrs); + } + + public List> tvsscan(String index, Long cursor, HscanParams params) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + + List> results = new ArrayList<>(); + for (String indexName : indexNames) { + ScanResult result = this.vectirInstance.tvsscan(indexName, cursor, params); + if (result == null) + continue; + else + results.add(result); + } + return results; + } + + + public List> tvsscan(byte[] index, Long cursor, HscanParams params) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + + List> results = new ArrayList<>(); + for (String indexName : indexNames) { + ScanResult result = this.vectirInstance.tvsscan(SafeEncoder.encode(indexName), cursor, params); + if (result == null) + continue; + else + results.add(result); + } + return results; + } + + public Knn tvsknnsearch(String index, Long topn, String vector, String... params) { + return tvsknnsearchfilter(index, topn, vector, "", params); + } + + public Knn tvsknnsearch(byte[] index, Long topn, byte[] vector, byte[]... params) { + return tvsknnsearchfilter(index, topn, vector, SafeEncoder.encode(""), params); + } + + public Collection> tvsmknnsearch(String index, Long topn, Collection vectors, String... params) { + return tvsmknnsearchfilter(index, topn, vectors, "", params); + } + + public Collection> tvsmknnsearch(byte[] index, Long topn, Collection vectors, byte[]... params) { + return tvsmknnsearchfilter(index, topn, vectors, SafeEncoder.encode(""), params); + } + + public Knn tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + Long shardTopN = topnforshard(topn, shardCount); + + List> rets = new ArrayList<>(); + for (int i = 0; i < indexNames.size(); ++i) { + rets.add(this.vectirInstance.tvsknnsearchfilter(indexNames.get(i), shardTopN, vector, pattern, params)); + } + return mergeSearchResult(rets, topn); + } + + public Knn tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + Long shardTopN = topnforshard(topn, shardCount); + + List> rets = new ArrayList<>(); + for (int i = 0; i < indexNames.size(); ++i) { + rets.add(this.vectirInstance.tvsknnsearchfilter(SafeEncoder.encode(indexNames.get(i)), shardTopN, vector, pattern, params)); + } + return mergeSearchResult(rets, topn); + } + + public Collection> tvsmknnsearchfilter(final String index, Long topn, Collection vectors, final String pattern, final String... params) { + List indexNames = null; + indexNames = defaultindexsplit(index, shardCount); + Long shardTopN = topnforshard(topn, shardCount); + + List>> rets = new ArrayList<>(); + for (int i = 0; i < vectors.size(); ++i) { rets.add(new ArrayList<>()); } + + for (int i = 0; i < indexNames.size(); ++i) { + Collection> shardRet = this.vectirInstance.tvsmknnsearchfilter(indexNames.get(i), shardTopN, vectors, pattern, params); + int vectorIdx = 0; + for (Knn vectorRet : shardRet) { + rets.get(vectorIdx).add(vectorRet); + vectorIdx++; + } + } + Collection> result = new ArrayList<>(); + for (List> ret : rets) { + result.add(mergeSearchResult(ret, topn)); + } + return result; + } + + public Collection> tvsmknnsearchfilter(byte[] index, Long topn, Collection vectors, byte[] pattern, final byte[]... params) { + List indexNames = null; + indexNames = defaultindexsplit(SafeEncoder.encode(index), shardCount); + Long shardTopN = topnforshard(topn, shardCount); + + List>> rets = new ArrayList<>(); + for (int i = 0; i < vectors.size(); ++i) { rets.add(new ArrayList<>()); } + + for (int i = 0; i < indexNames.size(); ++i) { + Collection> shardRet = this.vectirInstance.tvsmknnsearchfilter(SafeEncoder.encode(indexNames.get(i)), shardTopN, vectors, pattern, params); + int vectorIdx = 0; + for (Knn vectorRet : shardRet) { + rets.get(vectorIdx).add(vectorRet); + vectorIdx++; + } + } + Collection> result = new ArrayList<>(); + for (List> ret : rets) { + result.add(mergeSearchResult(ret, topn)); + } + + return result; + } + + static public List defaultindexsplit(final String index, final int shards) { + List nameList = new ArrayList<>(); + for (int i = 0; i < shards; ++i) { + nameList.add(String.join("_", index, String.valueOf(i))); + } + return nameList; + } + + static public Long topnforshard(final Long topn, final int shards) { + Long shardTopN = (long)Math.ceil(topn / shards * 1.1); + return shardTopN; + } + + static public Knn mergeSearchResult(List> rets, Long topn) { + Queue> queue = new PriorityQueue<>(); + for (Knn ret : rets) { + for (VectorBuilderFactory.KnnItem item : ret.getKnnResults()) { + queue.add(item); + } + } + + Knn mergeRets = new Knn<>(); + int count = queue.size(); + for (int i = 0; i < topn && i < count; ++i) { + mergeRets.add(queue.poll()); + } + + return mergeRets; + } +} diff --git a/src/main/java/com/aliyun/tair/tairvector/VectorShard.java b/src/main/java/com/aliyun/tair/tairvector/VectorShard.java new file mode 100644 index 0000000..80626ca --- /dev/null +++ b/src/main/java/com/aliyun/tair/tairvector/VectorShard.java @@ -0,0 +1,55 @@ +package com.aliyun.tair.tairvector; + +import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; +import com.aliyun.tair.tairvector.params.DistanceMethod; +import com.aliyun.tair.tairvector.params.HscanParams; +import com.aliyun.tair.tairvector.params.IndexAlgorithm; +import redis.clients.jedis.ScanResult; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +public interface VectorShard { + public void quit(); + public String tvscreateindex(final String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... params); + public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final byte[]... params); + + public Map tvsgetindex(final String index); + public Map tvsgetindex(final byte[] index); + + public Long tvsdelindex(final String index); + public Long tvsdelindex(final byte[] index); + + // public ScanResult tvsscanindex(Long cursor, final HscanParams params); + + public Long tvshset(final String index,final String key, final String vector, final String... params); + public Long tvshset(final byte[] index,final byte[] key, final byte[] vector, final byte[]... params); + + public Map tvshgetall(final String index, final String key); + public Map tvshgetall(final byte[] index, final byte[] key); + + public List tvshmget(final String index, final String key, final String... attrs); + public List tvshmget(final byte[] index, final byte[] key, final byte[]... attrs); + + public Long tvsdel(final String index, final String key); + public Long tvsdel(final byte[] index, final byte[] key); + + public Long tvshdel(final String index, final String key, final String... attrs); + public Long tvshdel(final byte[] index, final byte[] key, final byte[]... attrs); + + public ScanResulttvsscan(final String index, Long cursor, final HscanParams params); + public ScanResulttvsscan(final byte[] index, Long cursor, final HscanParams params); + + public VectorBuilderFactory.Knn tvsknnsearch(final String index, final Long topn, final String vector, final String... params); + public VectorBuilderFactory.Knn tvsknnsearch(final byte[] index, final Long topn, final byte[] vector, final byte[]... params); + + public VectorBuilderFactory.Knn tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params); + public VectorBuilderFactory.Knn tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params); + + public Collection> tvsmknnsearch(final String index, final Long topn, final Collection vectors, final String... params); + public Collection> tvsmknnsearch(final byte[] index, final Long topn, final Collection vectors, final byte[]... params); + + public Collection> tvsmknnsearchfilter(final String index, Long topn, Collection vectors, final String pattern, final String... params); + public Collection> tvsmknnsearchfilter(byte[] index, Long topn, Collection vectors, byte[] pattern, final byte[]... params); + } diff --git a/src/main/java/com/aliyun/tair/tairvector/factory/VectorBuilderFactory.java b/src/main/java/com/aliyun/tair/tairvector/factory/VectorBuilderFactory.java index adb3495..e85f5df 100644 --- a/src/main/java/com/aliyun/tair/tairvector/factory/VectorBuilderFactory.java +++ b/src/main/java/com/aliyun/tair/tairvector/factory/VectorBuilderFactory.java @@ -1,17 +1,17 @@ package com.aliyun.tair.tairvector.factory; +import redis.clients.jedis.Builder; +import redis.clients.jedis.ScanResult; +import redis.clients.jedis.util.SafeEncoder; + import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; -import redis.clients.jedis.Builder; -import redis.clients.jedis.ScanResult; -import redis.clients.jedis.util.SafeEncoder; - public class VectorBuilderFactory { public static final String VECTOR_TAG = "VECTOR"; - public static class KnnItem { + public static class KnnItem implements Comparable> { private T id; private double score; public KnnItem(T id, double score) { @@ -31,6 +31,9 @@ public double getScore() { public String toString() { return "id =" + id + ", score =" + score + ";"; } + + @Override + public int compareTo(KnnItem other) { return Double.compare(this.score, other.score); } } public static class Knn { diff --git a/src/main/java/com/aliyun/tair/tairvector/params/HscanParams.java b/src/main/java/com/aliyun/tair/tairvector/params/HscanParams.java index bb9b7ae..50dfa22 100644 --- a/src/main/java/com/aliyun/tair/tairvector/params/HscanParams.java +++ b/src/main/java/com/aliyun/tair/tairvector/params/HscanParams.java @@ -1,17 +1,11 @@ package com.aliyun.tair.tairvector.params; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; - import redis.clients.jedis.Protocol; import redis.clients.jedis.util.SafeEncoder; +import java.nio.ByteBuffer; +import java.util.*; + public class HscanParams { private final static String MATCH = "MATCH"; diff --git a/src/test/java/com/aliyun/tair/tests/example/VectorSearch.java b/src/test/java/com/aliyun/tair/tests/example/VectorSearch.java index 5d8ab99..527db02 100644 --- a/src/test/java/com/aliyun/tair/tests/example/VectorSearch.java +++ b/src/test/java/com/aliyun/tair/tests/example/VectorSearch.java @@ -27,7 +27,7 @@ public class VectorSearch { jedisPool = new JedisPool(config, HOST, PORT, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_SO_TIMEOUT, PASSWORD, 0, null); - tairVector = new TairVector(jedisPool); + tairVector = new TairVector(jedisPool.getResource()); } public static boolean createIndex(final String index, int dims, final String... attrs) { diff --git a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorClusterTest.java b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorClusterTest.java index 9c38756..77b5ca2 100644 --- a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorClusterTest.java +++ b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorClusterTest.java @@ -1,12 +1,5 @@ package com.aliyun.tair.tests.tairvector; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; import com.aliyun.tair.tairvector.params.DistanceMethod; import com.aliyun.tair.tairvector.params.HscanParams; @@ -15,29 +8,49 @@ import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; public class TairVectorClusterTest extends TairVectorTestBase { - final String index = "default_index_cluster"; + final String index = "default_index"; final int dims = 8; final IndexAlgorithm algorithm = IndexAlgorithm.HNSW; final DistanceMethod method = DistanceMethod.IP; + final long dbid = 2; + final List index_params = Arrays.asList("ef_construct", "100", "M", "16"); + final List index_params_with_dataType = Arrays.asList("ef_construct", "100", "M", "16","data_type","BINARY"); final List ef_params = Arrays.asList("ef_search", "100"); - private void tvs_create_index(int dims, IndexAlgorithm algorithm, DistanceMethod method) { - assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, method)); + /** + * 127.0.0.1:6379> tvs.createindex default_index 8 HNSW IP + */ + private void tvs_create_index(int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attr) { + tairVectorCluster.tvsdelindex(index); + assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, method, attr)); + } + + private void check_index(int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attr) { + Map objs = tairVectorCluster.tvsgetindex(index); + if (!objs.isEmpty()) { + long result = tairVectorCluster.tvsdelindex(index); + assertEquals(result, 1); + } + assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, method, attr)); } private void tvs_hset(final String entityid, final String vector, final String param_k, final String param_v) { long result = tairVectorCluster.tvshset(index, entityid, vector, param_k, param_v); - assertEquals(result, 1); + assertEquals(result, 2); } private void tvs_hset(byte[] entityid, byte[] vector, byte[] param_k, byte[] param_v) { long result = tairVectorCluster.tvshset(SafeEncoder.encode(index), entityid, vector, param_k, param_v); - assertEquals(result, 1); + assertTrue(result <= 2); } private long tvs_del_entity(String entity) { @@ -49,15 +62,49 @@ private long tvs_del_entity(byte[] entity) { } @Test - public void tvs_create_index_test() { - assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, method)); - assertNotEquals("OK", tairVectorCluster.tvscreateindex(SafeEncoder.encode(index), dims, algorithm, method)); + public void tvs_create_index() { + tvs_del_index(); + assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, method, index_params.toArray(new String[0]))); + try { + tairVectorCluster.tvscreateindex(SafeEncoder.encode(index), dims, algorithm, method); + } catch (Exception e) { + assertEquals(e.getMessage(), "ERR duplicated index key"); + } } + @Test + public void tvs_create_index_with_datatype() { + tvs_del_index(); + try { + tairVectorCluster.tvscreateindex(index, dims, algorithm, method, index_params_with_dataType.toArray(new String[0])); + }catch (Exception e){ + assertEquals(e.getMessage(), "ERR index parameters invalid"); + } + assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, DistanceMethod.JACCARD, index_params_with_dataType.toArray(new String[0]))); + try { + tairVectorCluster.tvscreateindex(SafeEncoder.encode(index), dims, algorithm, method); + } catch (Exception e) { + assertEquals(e.getMessage(), "ERR duplicated index key"); + } + } + + @Test + public void tvs_create_index_withoption_args() { + tvs_del_index(); + assertEquals("OK", tairVectorCluster.tvscreateindex(index, dims, algorithm, method, + "ef_construct", "50", "M", "20")); + Map schema = tairVectorCluster.tvsgetindex(index); + assertEquals(String.valueOf(50), schema.get("ef_construct")); + assertEquals(String.valueOf(20), schema.get("M")); + } + + /** + * 127.0.0.1:6379> tvs.getindex default_index + */ @Test public void tvs_get_index() { - tvs_create_index(dims, algorithm, method); + tvs_create_index(dims, algorithm, method, index_params.toArray(new String[0])); Map schema = tairVectorCluster.tvsgetindex(index); assertEquals(index, schema.get("index_name")); @@ -74,21 +121,9 @@ public void tvs_get_index() { } } - @Test - public void tvs_scan_index() { - tvs_create_index(dims, algorithm, method); - - HscanParams exhscanParams = new HscanParams(); - exhscanParams.count(5); - ScanResult result = tairVectorCluster.tvsscanindex(0L, exhscanParams); - assertEquals(String.valueOf(1), result.getCursor()); - assertEquals(1, result.getResult().size()); - assertEquals(index, result.getResult().get(0)); - } - @Test public void tvs_del_index() { - tvs_create_index(dims, algorithm, method); + check_index(dims, algorithm, method, index_params.toArray(new String[0])); Map schema = tairVectorCluster.tvsgetindex(index); assertEquals(index, schema.get("index_name")); @@ -96,215 +131,144 @@ public void tvs_del_index() { assertEquals(method.name(), schema.get("distance_method")); assertEquals(String.valueOf(0), schema.get("data_count")); - assertEquals((long) tairVectorCluster.tvsdelindex(index), 1L); + long result = tairVectorCluster.tvsdelindex(index); + assertEquals(result, 1); + long result_byte = tairVectorCluster.tvsdelindex(SafeEncoder.encode(index)); + assertEquals(result_byte, 0); + } + + @Test + public void tvs_hset_data_bin() { + check_index(dims, algorithm, DistanceMethod.JACCARD, index_params_with_dataType.toArray(new String[0])); + tvs_del_entity("fourth_entity_knn"); + tvs_hset("fourth_entity_knn", "[1,1,0,0,1,0,1,0]", "name", "sammy"); + tvs_del_entity("ten_entity_knn"); + tvs_hset(SafeEncoder.encode("ten_entity_knn"), SafeEncoder.encode("[1,1,0,0,1,0,1,0]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + } + + @Test + public void tvs_hgetall_data_bin() { + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + tvs_hset("first_entity_knn", "[1,1,1,1,0,0,0,0]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[1,1,1,1,0,0,0,0]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + Map entity_string = tairVectorCluster.tvshgetall(index, "first_entity_knn"); + assertEquals("[1,1,1,1,0,0,0,0]", entity_string.get(VectorBuilderFactory.VECTOR_TAG)); + assertEquals("sammy", entity_string.get("name")); + + Map entity_byte = tairVectorCluster.tvshgetall(SafeEncoder.encode(index), SafeEncoder.encode("first_entity_knn")); + assertEquals("[1,1,1,1,0,0,0,0]", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)))); + assertEquals("sammy", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode("name")))); } @Test public void tvs_hset() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("fourth_entity_knn"); + tvs_hset("fourth_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_del_entity("ten_entity_knn"); + tvs_hset(SafeEncoder.encode("ten_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); } @Test public void tvs_hgetall() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - Map entity_string = tairVectorCluster.tvshgetall(index, "first_entity"); - assertEquals("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", entity_string.get(VectorBuilderFactory.VECTOR_TAG)); + Map entity_string = tairVectorCluster.tvshgetall(index, "first_entity_knn"); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", entity_string.get(VectorBuilderFactory.VECTOR_TAG)); assertEquals("sammy", entity_string.get("name")); - Map entity_byte = tairVectorCluster.tvshgetall(SafeEncoder.encode(index), SafeEncoder.encode("first_entity")); - assertEquals(SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"), entity_byte.get(SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG))); - assertEquals(SafeEncoder.encode("sammy"), entity_byte.get(SafeEncoder.encode("name"))); + Map entity_byte = tairVectorCluster.tvshgetall(SafeEncoder.encode(index), SafeEncoder.encode("first_entity_knn")); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)))); + assertEquals("sammy", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode("name")))); } + + @Test public void tvs_hmgetall() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - List entity_string = tairVectorCluster.tvshmget(index, "first_entity", VectorBuilderFactory.VECTOR_TAG, "name"); - assertEquals("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", entity_string.get(0)); + List entity_string = tairVectorCluster.tvshmget(index, "first_entity_knn", VectorBuilderFactory.VECTOR_TAG, "name"); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", entity_string.get(0)); assertEquals("sammy", entity_string.get(1)); - List entity_byte = tairVectorCluster.tvshmget(SafeEncoder.encode(index), SafeEncoder.encode("first_entity"), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode("name")); - assertEquals(SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"), entity_byte.get(0)); - assertEquals(SafeEncoder.encode("sammy"), entity_byte.get(1)); + List entity_byte = tairVectorCluster.tvshmget(SafeEncoder.encode(index), SafeEncoder.encode("first_entity_knn"), + SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode("name")); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", SafeEncoder.encode(entity_byte.get(0))); + assertEquals("sammy", SafeEncoder.encode(entity_byte.get(1))); } @Test public void tvs_del() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - long count_string = tvs_del_entity("first_entity"); + long count_string = tvs_del_entity("first_entity_knn"); assertEquals(1, count_string); - long count_byte = tvs_del_entity(SafeEncoder.encode("second_entity")); + long count_byte = tvs_del_entity(SafeEncoder.encode("second_entity_knn")); assertEquals(1, count_byte); } @Test public void tvs_hdel() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - long count_string = tairVectorCluster.tvshdel(index, "first_entity", "name"); + long count_string = tairVectorCluster.tvshdel(index, "first_entity_knn", "name"); assertEquals(1, count_string); - Map entity_string = tairVectorCluster.tvshgetall(index, "first_entity"); + Map entity_string = tairVectorCluster.tvshgetall(index, "first_entity_knn"); assertTrue(entity_string.size() == 1 && (!entity_string.containsKey("name"))); - long count_byte = tairVectorCluster.tvshdel(SafeEncoder.encode(index), SafeEncoder.encode("second_entity"), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)); - assertEquals(1, count_byte); - Map entity_byte = tairVectorCluster.tvshgetall(index, "second_entity"); + long count_byte = tairVectorCluster.tvshdel(SafeEncoder.encode(index), SafeEncoder.encode("second_entity_knn"), + SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)); + //assertEquals(1, count_byte); + Map entity_byte = tairVectorCluster.tvshgetall(index, "second_entity_knn"); assertTrue(entity_byte.size() == 1 && (!entity_byte.containsKey(VectorBuilderFactory.VECTOR_TAG))); } @Test public void tvs_scan() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + tvs_del_entity("five_entity_knn"); + tvs_hset(SafeEncoder.encode("five_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); long cursor = 0; - HscanParams exhscanParams = new HscanParams(); - exhscanParams.count(1); - exhscanParams.match("*entit*"); - ScanResult result_string = tairVectorCluster.tvsscan(index, cursor, exhscanParams); + HscanParams hscanParams = new HscanParams(); + hscanParams.count(1); + hscanParams.match("*entit*"); + ScanResult result_string = tairVectorCluster.tvsscan(index, cursor, hscanParams); assert (result_string.getResult().size() >= 1); - ScanResult entity_byte = tairVectorCluster.tvsscan(SafeEncoder.encode(index), cursor, exhscanParams); + ScanResult entity_byte = tairVectorCluster.tvsscan(SafeEncoder.encode(index), cursor, hscanParams); assert (entity_byte.getResult().size() >= 1); } - - @Test - public void tvs_knnsearch() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - - long topn = 10L; - VectorBuilderFactory.Knn result_string = tairVectorCluster.tvsknnsearch(index, topn, "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"); - assertEquals(2, result_string.getKnnResults().size()); - - VectorBuilderFactory.Knn entity_byte = tairVectorCluster.tvsknnsearch(SafeEncoder.encode(index), topn, SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]")); - assertEquals(2, entity_byte.getKnnResults().size()); - } - - @Test - public void tvs_knnsearch_filter() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - - long topn = 10L; - VectorBuilderFactory.Knn result_string = tairVectorCluster.tvsknnsearchfilter(index, topn, "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name == \"sammy\""); - assertEquals(2, result_string.getKnnResults().size()); - - VectorBuilderFactory.Knn entity_byte = tairVectorCluster.tvsknnsearchfilter(SafeEncoder.encode(index), topn, SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"), SafeEncoder.encode("name == \"sammy\"")); - assertEquals(2, entity_byte.getKnnResults().size()); - } - - @Test - public void tvs_knnsearch_with_params() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), - SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - - long topn = 10L; - VectorBuilderFactory.Knn result_string = tairVectorCluster.tvsknnsearch(index, topn, - "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", ef_params.toArray(new String[0])); - assertEquals(2, result_string.getKnnResults().size()); - - VectorBuilderFactory.Knn entity_byte = tairVectorCluster.tvsknnsearch(SafeEncoder.encode(index), topn, - SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"), - SafeEncoder.encodeMany(ef_params.toArray(new String[0]))); - assertEquals(2, entity_byte.getKnnResults().size()); - } - - @Test - public void tvs_mknnsearch() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - - long topn = 10L; - List vectors = Arrays.asList("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"); - Collection> result_string = tairVectorCluster.tvsmknnsearch(index, topn, vectors); - assertEquals(2, result_string.size()); - result_string.forEach(one -> System.out.printf("string: %s\n", one.toString())); - - - Collection> entity_byte = tairVectorCluster.tvsmknnsearch(SafeEncoder.encode(index), topn, vectors.stream().map(item -> SafeEncoder.encode(item)).collect(Collectors.toList())); - assertEquals(2, entity_byte.size()); - result_string.forEach(one -> System.out.printf("byte: %s\n", one.toString())); - } - - @Test - public void tvs_mknnsearch_filter() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - - long topn = 10L; - List vectors = Arrays.asList("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"); - String pattern = "name == \"no-sammy\""; - Collection> result_string = tairVectorCluster.tvsmknnsearchfilter(index, topn, vectors, pattern); - assertEquals(2, result_string.size()); - result_string.forEach(one -> System.out.printf("string: %s\n", one.toString())); - - - Collection> entity_byte = tairVectorCluster.tvsmknnsearchfilter(SafeEncoder.encode(index), topn, vectors.stream().map(item -> SafeEncoder.encode(item)).collect(Collectors.toList()), SafeEncoder.encode(pattern)); - assertEquals(2, entity_byte.size()); - result_string.forEach(one -> System.out.printf("byte: %s\n", one.toString())); - } - - @Test - public void tvs_mknnsearch_with_params() { - tvs_del_entity("first_entity"); - tvs_del_entity("second_entity"); - - tvs_hset("first_entity", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); - tvs_hset(SafeEncoder.encode("second_entity"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), - SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); - - long topn = 10L; - List vectors = Arrays.asList("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"); - Collection> result_string = tairVectorCluster.tvsmknnsearch(index, topn, vectors, ef_params.toArray(new String[0])); - assertEquals(2, result_string.size()); - result_string.forEach(one -> System.out.printf("string: %s\n", one.toString())); - - - Collection> entity_byte = tairVectorCluster.tvsmknnsearch(SafeEncoder.encode(index), topn, - vectors.stream().map(item -> SafeEncoder.encode(item)).collect(Collectors.toList()), - SafeEncoder.encodeMany(ef_params.toArray(new String[0]))); - assertEquals(2, entity_byte.size()); - result_string.forEach(one -> System.out.printf("byte: %s\n", one.toString())); - } } diff --git a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorPipelineTest.java b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorPipelineTest.java index e2102ae..4d1b799 100644 --- a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorPipelineTest.java +++ b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorPipelineTest.java @@ -1,12 +1,5 @@ package com.aliyun.tair.tests.tairvector; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; import com.aliyun.tair.tairvector.params.DistanceMethod; import com.aliyun.tair.tairvector.params.HscanParams; @@ -15,9 +8,10 @@ import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertTrue; +import java.util.*; +import java.util.stream.Collectors; + +import static org.junit.Assert.*; public class TairVectorPipelineTest extends TairVectorTestBase { diff --git a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorShardTest.java b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorShardTest.java new file mode 100644 index 0000000..f1e6d86 --- /dev/null +++ b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorShardTest.java @@ -0,0 +1,439 @@ +package com.aliyun.tair.tests.tairvector; + +import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; +import com.aliyun.tair.tairvector.params.DistanceMethod; +import com.aliyun.tair.tairvector.params.HscanParams; +import com.aliyun.tair.tairvector.params.IndexAlgorithm; +import org.junit.Test; +import redis.clients.jedis.ScanResult; +import redis.clients.jedis.util.SafeEncoder; + +import java.util.*; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class TairVectorShardTest extends TairVectorTestBase { + + final String index = "default_index"; + final int dims = 8; + final IndexAlgorithm algorithm = IndexAlgorithm.HNSW; + final DistanceMethod method = DistanceMethod.IP; + final long dbid = 2; + final List index_params = Arrays.asList("ef_construct", "100", "M", "16"); + final List index_params_with_dataType = Arrays.asList("ef_construct", "100", "M", "16","data_type","BINARY"); + final List ef_params = Arrays.asList("ef_search", "100"); + + private String get_shard_index_name(String index_name){ + return index_name+"_0"; + } + /** + * 127.0.0.1:6379> tvs.createindex default_index 8 HNSW IP + */ + private void tvs_create_index(int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attr) { + tairVectorDistribute.tvsdelindex(index); + assertEquals("OK", tairVectorDistribute.tvscreateindex(index, dims, algorithm, method, attr)); + } + + private void check_index(int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attr) { + List> objs = tairVectorDistribute.tvsgetindex(index); + if (!objs.isEmpty()) { + long result = tairVectorDistribute.tvsdelindex(index); + assertEquals(result, 2); + } + assertEquals("OK", tairVectorDistribute.tvscreateindex(index, dims, algorithm, method, attr)); + } + + private void tvs_hset(final String entityid, final String vector, final String param_k, final String param_v) { + long result = tairVectorDistribute.tvshset(index, entityid, vector, param_k, param_v); + assertEquals(result, 2); + } + + private void tvs_hset(byte[] entityid, byte[] vector, byte[] param_k, byte[] param_v) { + long result = tairVectorDistribute.tvshset(SafeEncoder.encode(index), entityid, vector, param_k, param_v); + assertTrue(result <= 2); + } + + private long tvs_del_entity(String entity) { + return tairVectorDistribute.tvsdel(index, entity); + } + + private long tvs_del_entity(byte[] entity) { + return tairVectorDistribute.tvsdel(SafeEncoder.encode(index), entity); + } + + @Test + public void tvs_create_index() { + tvs_del_index(); + assertEquals("OK", tairVectorDistribute.tvscreateindex(index, dims, algorithm, method, index_params.toArray(new String[0]))); + try { + tairVectorDistribute.tvscreateindex(SafeEncoder.encode(index), dims, algorithm, method); + } catch (Exception e) { + assertEquals(e.getMessage(), "ERR duplicated index key"); + } + } + + + @Test + public void tvs_create_index_with_datatype() { + tvs_del_index(); + try { + tairVectorDistribute.tvscreateindex(index, dims, algorithm, method, index_params_with_dataType.toArray(new String[0])); + }catch (Exception e){ + assertEquals(e.getMessage(), "ERR index parameters invalid"); + } + assertEquals("OK", tairVectorDistribute.tvscreateindex(index, dims, algorithm, DistanceMethod.JACCARD, index_params_with_dataType.toArray(new String[0]))); + try { + tairVectorDistribute.tvscreateindex(SafeEncoder.encode(index), dims, algorithm, method); + } catch (Exception e) { + assertEquals(e.getMessage(), "ERR duplicated index key"); + } + } + + @Test + public void tvs_create_index_withoption_args() { + tvs_del_index(); + assertEquals("OK", tairVectorDistribute.tvscreateindex(index, dims, algorithm, method, + "ef_construct", "50", "M", "20")); + List> schema = tairVectorDistribute.tvsgetindex(index); + assertEquals(String.valueOf(50), schema.get(0).get("ef_construct")); + assertEquals(String.valueOf(20), schema.get(0).get("M")); + } + + /** + * 127.0.0.1:6379> tvs.getindex default_index + */ + @Test + public void tvs_get_index() { + tvs_create_index(dims, algorithm, method, index_params.toArray(new String[0])); + + List> schema = tairVectorDistribute.tvsgetindex(index); + assertEquals(get_shard_index_name(index), schema.get(0).get("index_name")); + assertEquals(algorithm.name(), schema.get(0).get("algorithm")); + assertEquals(method.name(), schema.get(0).get("distance_method")); + assertEquals(String.valueOf(0), schema.get(0).get("data_count")); + + + List> schema_bytecode = tairVectorDistribute.tvsgetindex(SafeEncoder.encode(index)); + Iterator> entries = schema_bytecode.get(0).entrySet().iterator(); + while (entries.hasNext()) { + Map.Entry entry = entries.next(); + assertEquals(schema.get(0).get(SafeEncoder.encode(entry.getKey())), SafeEncoder.encode(entry.getValue())); + } + } + + @Test + public void tvs_del_index() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + + List> schema = tairVectorDistribute.tvsgetindex(index); + assertEquals(get_shard_index_name(index), schema.get(0).get("index_name")); + assertEquals(algorithm.name(), schema.get(0).get("algorithm")); + assertEquals(method.name(), schema.get(0).get("distance_method")); + assertEquals(String.valueOf(0), schema.get(0).get("data_count")); + + long result = tairVectorDistribute.tvsdelindex(index); + assertEquals(result, 2); + long result_byte = tairVectorDistribute.tvsdelindex(SafeEncoder.encode(index)); + assertEquals(result_byte, 0); + } + + @Test + public void tvs_hset_data_bin() { + check_index(dims, algorithm, DistanceMethod.JACCARD, index_params_with_dataType.toArray(new String[0])); + tvs_del_entity("fourth_entity_knn"); + tvs_hset("fourth_entity_knn", "[1,1,0,0,1,0,1,0]", "name", "sammy"); + tvs_del_entity("ten_entity_knn"); + tvs_hset(SafeEncoder.encode("ten_entity_knn"), SafeEncoder.encode("[1,1,0,0,1,0,1,0]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + } + + @Test + public void tvs_hgetall_data_bin() { + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + tvs_hset("first_entity_knn", "[1,1,1,1,0,0,0,0]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[1,1,1,1,0,0,0,0]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + Map entity_string = tairVectorDistribute.tvshgetall(index, "first_entity_knn"); + assertEquals("[1,1,1,1,0,0,0,0]", entity_string.get(VectorBuilderFactory.VECTOR_TAG)); + assertEquals("sammy", entity_string.get("name")); + + Map entity_byte = tairVectorDistribute.tvshgetall(SafeEncoder.encode(index), SafeEncoder.encode("first_entity_knn")); + assertEquals("[1,1,1,1,0,0,0,0]", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)))); + assertEquals("sammy", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode("name")))); + } + + @Test + public void tvs_hset() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("fourth_entity_knn"); + tvs_hset("fourth_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_del_entity("ten_entity_knn"); + tvs_hset(SafeEncoder.encode("ten_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + } + + @Test + public void tvs_hgetall() { + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + Map entity_string = tairVectorDistribute.tvshgetall(index, "first_entity_knn"); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", entity_string.get(VectorBuilderFactory.VECTOR_TAG)); + assertEquals("sammy", entity_string.get("name")); + + Map entity_byte = tairVectorDistribute.tvshgetall(SafeEncoder.encode(index), SafeEncoder.encode("first_entity_knn")); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)))); + assertEquals("sammy", SafeEncoder.encode(entity_byte.get(SafeEncoder.encode("name")))); + } + + @Test + public void tvs_hmgetall() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + List entity_string = tairVectorDistribute.tvshmget(index, "first_entity_knn", VectorBuilderFactory.VECTOR_TAG, "name"); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", entity_string.get(0)); + assertEquals("sammy", entity_string.get(1)); + + List entity_byte = tairVectorDistribute.tvshmget(SafeEncoder.encode(index), SafeEncoder.encode("first_entity_knn"), + SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode("name")); + assertEquals("[0.12,0.23,0.56,0.67,0.78,0.89,0.01,0.89]", SafeEncoder.encode(entity_byte.get(0))); + assertEquals("sammy", SafeEncoder.encode(entity_byte.get(1))); + } + + @Test + public void tvs_del() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long count_string = tvs_del_entity("first_entity_knn"); + assertEquals(1, count_string); + + long count_byte = tvs_del_entity(SafeEncoder.encode("second_entity_knn")); + assertEquals(1, count_byte); + } + + @Test + public void tvs_hdel() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long count_string = tairVectorDistribute.tvshdel(index, "first_entity_knn", "name"); + assertEquals(1, count_string); + Map entity_string = tairVectorDistribute.tvshgetall(index, "first_entity_knn"); + assertTrue(entity_string.size() == 1 && (!entity_string.containsKey("name"))); + + long count_byte = tairVectorDistribute.tvshdel(SafeEncoder.encode(index), SafeEncoder.encode("second_entity_knn"), + SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG)); + //assertEquals(1, count_byte); + Map entity_byte = tairVectorDistribute.tvshgetall(index, "second_entity_knn"); + assertTrue(entity_byte.size() == 1 && (!entity_byte.containsKey(VectorBuilderFactory.VECTOR_TAG))); + } + + @Test + public void tvs_scan() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + + tvs_del_entity("five_entity_knn"); + tvs_hset(SafeEncoder.encode("five_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long cursor = 0; + HscanParams hscanParams = new HscanParams(); + hscanParams.count(1); + hscanParams.match("*entit*"); + List> result_string = tairVectorDistribute.tvsscan(index, cursor, hscanParams); + assert (result_string.get(0).getResult().size() >= 1); + + List> entity_byte = tairVectorDistribute.tvsscan(SafeEncoder.encode(index), cursor, hscanParams); + assert (entity_byte.get(0).getResult().size() >= 1); + } + + @Test + public void tvs_knnsearch() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity(SafeEncoder.encode("second_entity_knn")); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 2L; + VectorBuilderFactory.Knn result_string = tairVectorDistribute.tvsknnsearch(index, topn, "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"); + assertEquals(2, result_string.getKnnResults().size()); + + VectorBuilderFactory.Knn entity_byte = tairVectorDistribute.tvsknnsearch(SafeEncoder.encode(index), topn, + SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]")); + assertEquals(2, entity_byte.getKnnResults().size()); + } + + @Test + public void tvs_knnsearch_with_databin() { + check_index(dims, algorithm, DistanceMethod.JACCARD, index_params_with_dataType.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity(SafeEncoder.encode("second_entity_knn")); + + tvs_hset("first_entity_knn", "[1,1,1,1,0,0,0,0]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[1,1,1,1,0,0,0,0]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 2L; + VectorBuilderFactory.Knn result_string = tairVectorDistribute.tvsknnsearch(index, topn, "[1,1,1,1,0,0,0,0]"); + assertEquals(2, result_string.getKnnResults().size()); + + VectorBuilderFactory.Knn entity_byte = tairVectorDistribute.tvsknnsearch(SafeEncoder.encode(index), topn, + SafeEncoder.encode("[1,1,1,1,0,0,0,0]")); + assertEquals(2, entity_byte.getKnnResults().size()); + } + + @Test + public void tvs_knnsearch_with_filter() { + tairVectorDistribute.tvsdelindex(SafeEncoder.encode(index)); + + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity(SafeEncoder.encode("second_entity_knn")); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 10L; + VectorBuilderFactory.Knn result_string = tairVectorDistribute.tvsknnsearchfilter(index, topn, + "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name == \"sammy\""); + assertEquals(1, result_string.getKnnResults().size()); + + VectorBuilderFactory.Knn entity_byte = tairVectorDistribute.tvsknnsearchfilter(SafeEncoder.encode(index), topn, + SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"), SafeEncoder.encode("name != \"sammy\"")); + assertEquals(1, entity_byte.getKnnResults().size()); + } + + @Test + public void tvs_knnsearch_with_params() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity(SafeEncoder.encode("second_entity_knn")); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 2L; + VectorBuilderFactory.Knn result_string = tairVectorDistribute.tvsknnsearch(index, topn, + "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", ef_params.toArray(new String[0])); + assertEquals(2, result_string.getKnnResults().size()); + + VectorBuilderFactory.Knn entity_byte = tairVectorDistribute.tvsknnsearch(SafeEncoder.encode(index), topn, + SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]"), + SafeEncoder.encodeMany(ef_params.toArray(new String[0]))); + assertEquals(2, entity_byte.getKnnResults().size()); + } + + @Test + public void tvs_mknnsearch() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 2L; + List vectors = Arrays.asList("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"); + Collection> result_string = tairVectorDistribute.tvsmknnsearch(index, topn, vectors); + result_string.forEach(result -> { + assertEquals(2, result.getKnnResults().size()); + }); + result_string.forEach(one -> System.out.printf("string: %s\n", one.toString())); + + + Collection> result_byte = tairVectorDistribute.tvsmknnsearch(SafeEncoder.encode(index), topn, + vectors.stream().map(item -> SafeEncoder.encode(item)).collect(Collectors.toList())); + result_byte.forEach(result -> { + assertEquals(2, result.getKnnResults().size()); + }); + result_string.forEach(one -> System.out.printf("byte: %s\n", one.toString())); + } + + @Test + public void tvs_mknnsearch_filter() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 1L; + List vectors = Arrays.asList("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"); + String pattern = "name == \"no-sammy\""; + Collection> result_string = tairVectorDistribute.tvsmknnsearchfilter(index, topn, vectors, pattern); + result_string.forEach(result -> { + assertEquals(0, result.getKnnResults().size()); + }); + result_string.forEach(one -> System.out.printf("string: %s\n", one.toString())); + + + Collection> result_byte = tairVectorDistribute.tvsmknnsearchfilter(SafeEncoder.encode(index), + topn, vectors.stream().map(item -> SafeEncoder.encode(item)).collect(Collectors.toList()), SafeEncoder.encode(pattern)); + result_byte.forEach(result -> { + assertEquals(0, result.getKnnResults().size()); + }); + result_string.forEach(one -> System.out.printf("byte: %s\n", one.toString())); + } + + @Test + public void tvs_mknnsearch_with_params() { + check_index(dims, algorithm, method, index_params.toArray(new String[0])); + tvs_del_entity("first_entity_knn"); + tvs_del_entity("second_entity_knn"); + + tvs_hset("first_entity_knn", "[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "name", "sammy"); + tvs_hset(SafeEncoder.encode("second_entity_knn"), SafeEncoder.encode("[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"), + SafeEncoder.encode("name"), SafeEncoder.encode("tiddy")); + + long topn = 2L; + List vectors = Arrays.asList("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]", "[0.22, 0.33, 0.66, 0.77, 0.88, 0.89, 0.11, 0.89]"); + Collection> result_string = tairVectorDistribute.tvsmknnsearch(index, topn, vectors, + ef_params.toArray(new String[0])); + result_string.forEach(result -> { + assertEquals(2, result.getKnnResults().size()); + }); + result_string.forEach(one -> System.out.printf("string: %s\n", one.toString())); + + + Collection> result_byte = tairVectorDistribute.tvsmknnsearch(SafeEncoder.encode(index), topn, + vectors.stream().map(item -> SafeEncoder.encode(item)).collect(Collectors.toList()), + SafeEncoder.encodeMany(ef_params.toArray(new String[0]))); + result_byte.forEach(result -> { + assertEquals(2, result.getKnnResults().size()); + }); + result_string.forEach(one -> System.out.printf("byte: %s\n", one.toString())); + } +} diff --git a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTest.java b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTest.java index c084a70..2a85a61 100644 --- a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTest.java +++ b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTest.java @@ -1,12 +1,5 @@ package com.aliyun.tair.tests.tairvector; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import com.aliyun.tair.tairvector.factory.VectorBuilderFactory; import com.aliyun.tair.tairvector.params.DistanceMethod; import com.aliyun.tair.tairvector.params.HscanParams; @@ -15,6 +8,9 @@ import redis.clients.jedis.ScanResult; import redis.clients.jedis.util.SafeEncoder; +import java.util.*; +import java.util.stream.Collectors; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -309,6 +305,7 @@ public void tvs_knnsearch() { SafeEncoder.encode("[0.12, 0.23, 0.56, 0.67, 0.78, 0.89, 0.01, 0.89]")); assertEquals(2, entity_byte.getKnnResults().size()); } + @Test public void tvs_knnsearch_with_databin() { check_index(dims, algorithm, DistanceMethod.JACCARD, index_params_with_dataType.toArray(new String[0])); diff --git a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java index 2dc2ad4..a00f25d 100644 --- a/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java +++ b/src/test/java/com/aliyun/tair/tests/tairvector/TairVectorTestBase.java @@ -1,27 +1,51 @@ package com.aliyun.tair.tests.tairvector; -import java.util.List; - import com.aliyun.tair.tairvector.TairVector; import com.aliyun.tair.tairvector.TairVectorCluster; import com.aliyun.tair.tairvector.TairVectorPipeline; +import com.aliyun.tair.tairvector.TairVectorShard; import com.aliyun.tair.tests.TestBase; import org.junit.AfterClass; import org.junit.BeforeClass; import redis.clients.jedis.ScanResult; +import java.util.List; + import static com.aliyun.tair.tests.AssertUtil.assertEquals; public class TairVectorTestBase extends TestBase { public static TairVector tairVector; public static TairVectorPipeline tairVectorPipeline; public static TairVectorCluster tairVectorCluster; + public static TairVectorShard tairVectorDistribute; @BeforeClass public static void setUp() { - tairVector = new TairVector(jedisPool); + tairVector = new TairVector(jedis); tairVectorPipeline = new TairVectorPipeline(); tairVectorPipeline.setClient(jedis.getClient()); tairVectorCluster = new TairVectorCluster(jedisCluster); + tairVectorDistribute = new TairVectorShard(tairVectorCluster,2); + + } + + @AfterClass + public static void closeDown() { + tairVector.quit(); + tairVectorCluster.quit(); + tairVectorDistribute.quit(); + } + + public static void assertLongListEquals(List expected, List actual) { + assertEquals(expected.size(), actual.size()); + for (int n = 0; n < expected.size(); n++) { + assertEquals(expected.get(n), actual.get(n)); + } + } + + public static void assertScanResultEquals(List expected, ScanResult actual) { + for (int n = 0; n < expected.size(); n++) { + assertEquals(expected.get(n), actual.getResult().get(n)); + } } }