Skip to content

TairVector: support cluster #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>com.aliyun.tair</groupId>
<artifactId>alibabacloud-tairjedis-sdk</artifactId>
<version>2.4.0-SNAPSHOT</version>
<version>3.0.4</version>
<packaging>jar</packaging>

<name>alibabacloud-tairjedis-sdk</name>
Expand Down
317 changes: 107 additions & 210 deletions src/main/java/com/aliyun/tair/tairvector/TairVector.java

Large diffs are not rendered by default.

65 changes: 34 additions & 31 deletions src/main/java/com/aliyun/tair/tairvector/TairVectorCluster.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
package com.aliyun.tair.tairvector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.aliyun.tair.ModuleCommand;
import com.aliyun.tair.tairhash.factory.HashBuilderFactory;
import com.aliyun.tair.tairvector.factory.VectorBuilderFactory;
import com.aliyun.tair.tairvector.params.DistanceMethod;
import com.aliyun.tair.tairvector.params.HscanParams;
Expand All @@ -19,15 +11,19 @@
import redis.clients.jedis.ScanResult;
import redis.clients.jedis.util.SafeEncoder;

import java.util.*;
import java.util.stream.Collectors;

import static redis.clients.jedis.Protocol.toByteArray;

public class TairVectorCluster {
public class TairVectorCluster implements VectorShard {
private JedisCluster jc;

public TairVectorCluster(JedisCluster jc) {
this.jc = jc;
}

@Override
public void quit() {
if (jc != null) {
jc.close();
Expand All @@ -46,11 +42,13 @@ public void quit() {
* @param attrs other columns, optional
* @return Success: +OK; Fail: error
*/
@Override
public String tvscreateindex(final String index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final String... attrs) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), SafeEncoder.encodeMany(attrs)));
return BuilderFactory.STRING.build(obj);
}

@Override
public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, DistanceMethod method, final byte[]... params) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSCREATEINDEX, JoinParameters.joinParameters(index, toByteArray(dims), SafeEncoder.encode(algorithm.name()), SafeEncoder.encode(method.name()), params));
return BuilderFactory.BYTE_ARRAY.build(obj);
Expand All @@ -64,11 +62,13 @@ public byte[] tvscreateindex(byte[] index, int dims, IndexAlgorithm algorithm, D
* @param index index name
* @return Success: string_map, Fail: empty
*/
@Override
public Map<String, String> tvsgetindex(final String index) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSGETINDEX, SafeEncoder.encode(index));
return BuilderFactory.STRING_MAP.build(obj);
}

@Override
public Map<byte[], byte[]> tvsgetindex(byte[] index) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSGETINDEX, index);
return BuilderFactory.BYTE_ARRAY_MAP.build(obj);
Expand All @@ -82,37 +82,18 @@ public Map<byte[], byte[]> tvsgetindex(byte[] index) {
* @param index index name
* @return Success: 1; Fail: 0
*/
@Override
public Long tvsdelindex(final String index) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSDELINDEX, SafeEncoder.encode(index));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvsdelindex(byte[] index) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSDELINDEX, index);
return BuilderFactory.LONG.build(obj);
}


/**
* TVS.SCANINDEX TVS.SCANINDEX index_name
* <p>
* scan index
*
* @param cursor start offset
* @param params the params: [MATCH pattern] [COUNT count]
* `MATCH` - Set the pattern which is used to filter the results
* `COUNT` - Set the number of fields in a single scan (default is 10)
* `NOVAL` - The return result contains no data portion, only cursor information
* @return A ScanResult. {@link HashBuilderFactory#EXHSCAN_RESULT_STRING}
*/
public ScanResult<String> tvsscanindex(Long cursor, HscanParams params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(toByteArray(cursor));
args.addAll(params.getParams());
Object obj = jc.sendCommand(toByteArray(cursor), ModuleCommand.TVSSCANINDEX, args.toArray(new byte[args.size()][]));
return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj);
}

/**
* TVS.HSET TVS.HSET index entityid vector [(attribute_key attribute_value) ...]
* <p>
Expand All @@ -126,11 +107,13 @@ public ScanResult<String> tvsscanindex(Long cursor, HscanParams params) {
* {@literal k} if success, k is the number of fields that were added..
* throw error like "(error) Illegal vector dimensions" if error
*/
@Override
public Long tvshset(final String index, final String entityid, final String vector, final String... params) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHSET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), SafeEncoder.encode(vector), SafeEncoder.encodeMany(params)));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]... params) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHSET, JoinParameters.joinParameters(index, entityid, SafeEncoder.encode(VectorBuilderFactory.VECTOR_TAG), vector, params));
return BuilderFactory.LONG.build(obj);
Expand All @@ -145,11 +128,13 @@ public Long tvshset(byte[] index, byte[] entityid, byte[] vector, final byte[]..
* @param entityid entity id
* @return Map, an empty list when {@code entityid} does not exist.
*/
@Override
public Map<String, String> tvshgetall(final String index, final String entityid) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHGETALL, SafeEncoder.encode(index), SafeEncoder.encode(entityid));
return BuilderFactory.STRING_MAP.build(obj);
}

@Override
public Map<byte[], byte[]> tvshgetall(byte[] index, byte[] entityid) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHGETALL, index, entityid);
return BuilderFactory.BYTE_ARRAY_MAP.build(obj);
Expand All @@ -165,11 +150,13 @@ public Map<byte[], byte[]> tvshgetall(byte[] index, byte[] entityid) {
* @param attrs attrs
* @return List, an empty list when {@code entityid} or {@code attrs} does not exist .
*/
@Override
public List<String> tvshmget(final String index, final String entityid, final String... attrs) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHMGET, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs)));
return BuilderFactory.STRING_LIST.build(obj);
}

@Override
public List<byte[]> tvshmget(byte[] index, byte[] entityid, byte[]... attrs) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHMGET, JoinParameters.joinParameters(index, entityid, attrs));
return BuilderFactory.BYTE_ARRAY_LIST.build(obj);
Expand All @@ -186,11 +173,13 @@ public List<byte[]> tvshmget(byte[] index, byte[] entityid, byte[]... attrs) {
* @return Long integer-reply the number of fields that were removed from the tair-vector
* not including specified but non existing fields.
*/
@Override
public Long tvsdel(final String index, final String entityid) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSDEL, SafeEncoder.encode(index), SafeEncoder.encode(entityid));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvsdel(byte[] index, byte[] entityid) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSDEL, index, entityid);
return BuilderFactory.LONG.build(obj);
Expand All @@ -207,13 +196,15 @@ public Long tvsdel(byte[] index, byte[] entityid) {
* @return Long integer-reply the number of fields that were removed from the tair-vector
* not including specified but non existing fields.
*/
@Override
public Long tvshdel(final String index, final String entityid, final String... attrs) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSHDEL, JoinParameters.joinParameters(SafeEncoder.encode(index), SafeEncoder.encode(entityid), SafeEncoder.encodeMany(attrs)));
return BuilderFactory.LONG.build(obj);
}

@Override
public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid, attrs));
Object obj = jc.sendCommand(index, ModuleCommand.TVSHDEL, JoinParameters.joinParameters(index, entityid,attrs));
return BuilderFactory.LONG.build(obj);
}

Expand All @@ -231,6 +222,7 @@ public Long tvshdel(byte[] index, byte[] entityid, byte[]... attrs) {
* `NOVAL` - The return result contains no data portion, only cursor information
* @return A ScanResult.
*/
@Override
public ScanResult<String> tvsscan(final String index, Long cursor, HscanParams params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(SafeEncoder.encode(index));
Expand All @@ -240,6 +232,7 @@ public ScanResult<String> tvsscan(final String index, Long cursor, HscanParams p
return VectorBuilderFactory.SCAN_CURSOR_STRING.build(obj);
}

@Override
public ScanResult<byte[]> tvsscan(byte[] index, Long cursor, HscanParams params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(index);
Expand All @@ -261,10 +254,12 @@ public ScanResult<byte[]> tvsscan(byte[] index, Long cursor, HscanParams params)
* ef_search range [0, 1000]
* @return VectorBuilderFactory.Knn<>
*/
@Override
public VectorBuilderFactory.Knn<String> tvsknnsearch(final String index, Long topn, final String vector, final String... params) {
return tvsknnsearchfilter(index, topn, vector, "", params);
}

@Override
public VectorBuilderFactory.Knn<byte[]> tvsknnsearch(byte[] index, Long topn, byte[] vector, final byte[]... params) {
return tvsknnsearchfilter(index, topn, vector, SafeEncoder.encode(""), params);
}
Expand All @@ -282,12 +277,14 @@ public VectorBuilderFactory.Knn<byte[]> tvsknnsearch(byte[] index, Long topn, by
* ef_search range [0, 1000]
* @return VectorBuilderFactory.Knn<>
*/
@Override
public VectorBuilderFactory.Knn<String> tvsknnsearchfilter(final String index, Long topn, final String vector, final String pattern, final String... params) {
Object obj = jc.sendCommand(SafeEncoder.encode(index), ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(SafeEncoder.encode(index), toByteArray(topn),
SafeEncoder.encode(vector), SafeEncoder.encode(pattern), SafeEncoder.encodeMany(params)));
return VectorBuilderFactory.STRING_KNN_RESULT.build(obj);
}

@Override
public VectorBuilderFactory.Knn<byte[]> tvsknnsearchfilter(byte[] index, Long topn, byte[] vector, byte[] pattern, final byte[]... params) {
Object obj = jc.sendCommand(index, ModuleCommand.TVSKNNSEARCH, JoinParameters.joinParameters(index, toByteArray(topn), vector, pattern, params));
return VectorBuilderFactory.BYTE_KNN_RESULT.build(obj);
Expand All @@ -303,10 +300,12 @@ public VectorBuilderFactory.Knn<byte[]> tvsknnsearchfilter(byte[] index, Long to
* ef_search range [0, 1000]
* @return Collection<>
*/
@Override
public Collection<VectorBuilderFactory.Knn<String>> tvsmknnsearch(final String index, Long topn, Collection<String> vectors, final String... params) {
return tvsmknnsearchfilter(index, topn, vectors, "", params);
}

@Override
public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearch(byte[] index, Long topn, Collection<byte[]> vectors, final byte[]... params) {
return tvsmknnsearchfilter(index, topn, vectors, SafeEncoder.encode(""), params);
}
Expand All @@ -322,6 +321,7 @@ public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearch(byte[] index,
* ef_search range [0, 1000]
* @return Collection<>
*/
@Override
public Collection<VectorBuilderFactory.Knn<String>> tvsmknnsearchfilter(final String index, Long topn, Collection<String> vectors, final String pattern, final String... params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(SafeEncoder.encode(index));
Expand All @@ -334,6 +334,7 @@ public Collection<VectorBuilderFactory.Knn<String>> tvsmknnsearchfilter(final St
return VectorBuilderFactory.STRING_KNN_BATCH_RESULT.build(obj);
}

@Override
public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearchfilter(byte[] index, Long topn, Collection<byte[]> vectors, byte[] pattern, final byte[]... params) {
final List<byte[]> args = new ArrayList<byte[]>();
args.add(index);
Expand All @@ -345,4 +346,6 @@ public Collection<VectorBuilderFactory.Knn<byte[]>> tvsmknnsearchfilter(byte[] i
Object obj = jc.sendCommand(index, ModuleCommand.TVSMKNNSEARCH, args.toArray(new byte[args.size()][]));
return VectorBuilderFactory.BYTE_KNN_BATCH_RESULT.build(obj);
}


}
10 changes: 3 additions & 7 deletions src/main/java/com/aliyun/tair/tairvector/TairVectorPipeline.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
package com.aliyun.tair.tairvector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.aliyun.tair.ModuleCommand;
import com.aliyun.tair.tairhash.factory.HashBuilderFactory;
import com.aliyun.tair.tairvector.factory.VectorBuilderFactory;
Expand All @@ -20,6 +13,9 @@
import redis.clients.jedis.ScanResult;
import redis.clients.jedis.util.SafeEncoder;

import java.util.*;
import java.util.stream.Collectors;

import static redis.clients.jedis.Protocol.toByteArray;

public class TairVectorPipeline extends Pipeline {
Expand Down
Loading