Skip to content

Commit 1bdcb7b

Browse files
committed
Add Scorer
1 parent 8cbc004 commit 1bdcb7b

File tree

5 files changed

+155
-5
lines changed

5 files changed

+155
-5
lines changed

build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ allprojects {
4343

4444
ext {
4545
pomGroupID = 'ai.improve'
46-
sdkVersion = '7.2.0'
46+
sdkVersion = '8.0.0'
4747
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package ai.improve;
2+
3+
import androidx.test.ext.junit.runners.AndroidJUnit4;
4+
5+
import org.junit.Test;
6+
import org.junit.runner.RunWith;
7+
8+
import java.io.IOException;
9+
import java.util.Arrays;
10+
import java.util.List;
11+
12+
import ai.improve.log.IMPLog;
13+
14+
@RunWith(AndroidJUnit4.class)
15+
public class TestScorer {
16+
public static final String Tag = "TestScorer";
17+
18+
public static final String ModelUrl = "https://improveai-mindblown-mindful-prod-models.s3.amazonaws.com/models/latest/songs-2.0.xgb.gz";
19+
20+
public static final String DummyV8ModelUrl = "file:///android_asset/dummy_v8.xgb";
21+
22+
static {
23+
IMPLog.setLogLevel(IMPLog.LOG_LEVEL_ALL);
24+
}
25+
26+
@Test
27+
public void testScore() throws IOException, InterruptedException {
28+
Scorer scorer = new Scorer(ModelUrl);
29+
List scores = scorer.score(Arrays.asList(0, 1, 2));
30+
IMPLog.d(Tag, "scores: " + scores);
31+
}
32+
33+
@Test
34+
public void testLoad_v8_model() throws IOException, InterruptedException {
35+
Scorer scorer = new Scorer(DummyV8ModelUrl);
36+
List scores = scorer.score(Arrays.asList(0, 1, 2));
37+
IMPLog.d(Tag, "scores: " + scores);
38+
}
39+
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package ai.improve;
2+
3+
import java.io.IOException;
4+
import java.net.URL;
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
import java.util.concurrent.CountDownLatch;
8+
9+
import ai.improve.downloader.ModelDownloader;
10+
import ai.improve.encoder.FeatureEncoder;
11+
import ai.improve.log.IMPLog;
12+
import ai.improve.xgbpredictor.ImprovePredictor;
13+
import biz.k11i.xgboost.util.FVec;
14+
15+
public class Scorer {
16+
public static final String Tag = "Scorer";
17+
18+
private CountDownLatch loadModelSignal = new CountDownLatch(1);
19+
20+
private ImprovePredictor predictor;
21+
22+
private FeatureEncoder featureEncoder;
23+
24+
private boolean enableTieBreaker = true;
25+
26+
public Scorer(String modelUrl) throws IOException, InterruptedException {
27+
loadModel(new URL(modelUrl));
28+
if(predictor == null) {
29+
throw new IOException("Failed to load model " + modelUrl);
30+
}
31+
}
32+
33+
public <T> List<Double> score(List<?> items) {
34+
return score(items, null);
35+
}
36+
37+
public <T> List<Double> score(List<?> items, T context) {
38+
if(items == null || items.size() <= 0) {
39+
throw new IllegalArgumentException("variants can't be null or empty");
40+
}
41+
42+
List<Double> result = new ArrayList<>();
43+
List<FVec> encodedFeatures = featureEncoder.encodeFeatureVectors(items, context, Math.random());
44+
for (FVec fvec : encodedFeatures) {
45+
46+
if(enableTieBreaker) {
47+
// add a very small random number to randomly break ties
48+
double smallNoise = Math.random() * Math.pow(2, -23);
49+
result.add((double) predictor.predictSingle(fvec) + smallNoise);
50+
} else {
51+
result.add((double) predictor.predictSingle(fvec));
52+
}
53+
}
54+
55+
return result;
56+
}
57+
58+
private synchronized void setModel(ImprovePredictor predictor) {
59+
this.predictor = predictor;
60+
61+
featureEncoder = new FeatureEncoder(predictor.getModelMetadata().getModelFeatureNames(),
62+
predictor.getModelMetadata().getStringTables(),
63+
predictor.getModelMetadata().getModelSeed());
64+
}
65+
66+
private void loadModel(URL modelUrl) throws InterruptedException {
67+
ModelDownloader.download(modelUrl, (predictor, e) -> {
68+
if(e != null) {
69+
IMPLog.e(Tag, "Failed to load model, " + e.getMessage());
70+
loadModelSignal.countDown();
71+
return;
72+
}
73+
74+
setModel(predictor);
75+
loadModelSignal.countDown();
76+
77+
});
78+
loadModelSignal.await();
79+
}
80+
}

improveai/src/main/java/ai/improve/encoder/FeatureEncoder.java

+9
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,16 @@ public void encodeFeatureVector(Object item, Object context, double[] into, doub
113113
if (context != null) {
114114
encodeContext(context, into, noiseShiftAndScale[0], noiseShiftAndScale[1]);
115115
}
116+
}
116117

118+
public List<FVec> encodeFeatureVectors(List<?> items, Object context, double noise) {
119+
List<FVec> result = new ArrayList<>();
120+
for(int i = 0; i < items.size(); ++i) {
121+
double[] fvalues = new double[this.featureIndexes.size()];
122+
encodeFeatureVector(items.get(i), context, fvalues, noise);
123+
result.add(FVec.Transformer.fromArray(fvalues, false));
124+
}
125+
return result;
117126
}
118127

119128
/**

improveai/src/main/java/ai/improve/xgbpredictor/ModelMetadata.java

+25-4
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
import ai.improve.constants.BuildProperties;
44
import biz.k11i.xgboost.util.ModelReader;
5+
6+
import com.google.gson.Gson;
57
import com.google.gson.JsonArray;
68
import com.google.gson.JsonObject;
79
import com.google.gson.JsonParser;
10+
import com.google.gson.reflect.TypeToken;
811

912
import java.io.IOException;
13+
import java.lang.reflect.Type;
1014
import java.util.ArrayList;
1115
import java.util.HashMap;
1216
import java.util.List;
@@ -19,6 +23,14 @@ public class ModelMetadata {
1923

2024
public static final String IMPROVE_VERSION_KEY = "ai.improve.version";
2125

26+
public static final String IMPROVE_SEED_KEY = "ai.improve.seed";
27+
28+
public static final String IMPROVE_MODEL_NAME_KEY = "ai.improve.model";
29+
30+
public static final String IMPROVE_FEAtURES_KEY = "ai.improve.features";
31+
32+
public static final String IMPROVE_STRING_TABLES_KEY = "ai.improve.string_tables";
33+
2234
private Map<String, String> storage = new HashMap<>();
2335

2436
private String modelName;
@@ -27,6 +39,8 @@ public class ModelMetadata {
2739

2840
private List<String> modelFeatureNames;
2941

42+
private Map<String, List<Long>> stringTables;
43+
3044
public ModelMetadata(ModelReader r) throws IOException {
3145
long num_attrs = r.readLong();
3246
for (long i = 0; i < num_attrs; ++i) {
@@ -56,6 +70,10 @@ public List<String> getModelFeatureNames() {
5670
return modelFeatureNames;
5771
}
5872

73+
public Map<String, List<Long>> getStringTables() {
74+
return stringTables;
75+
}
76+
5977
public String getValue(String key) {
6078
return storage.get(key);
6179
}
@@ -66,22 +84,25 @@ public String getUserDefinedMetadata() {
6684

6785
private void parseMetadata(String value) throws IOException {
6886
try {
69-
JsonObject root = JsonParser.parseString(value).getAsJsonObject().getAsJsonObject("json");
87+
JsonObject root = JsonParser.parseString(value).getAsJsonObject();
7088
if(root.has(IMPROVE_VERSION_KEY)) {
7189
String modelVersion = root.get(IMPROVE_VERSION_KEY).getAsString();
7290
if(!canParseModel(modelVersion, BuildProperties.getSDKVersion())) {
7391
throw new IOException("Major version don't match. ImproveAI SDK version(" + BuildProperties.getSDKVersion()+") " +
7492
"can't load the model of version("+ modelVersion + ").");
7593
}
7694
}
77-
modelName = root.get("model_name").getAsString();
78-
modelSeed = root.get("model_seed").getAsLong();
95+
modelName = root.get(IMPROVE_MODEL_NAME_KEY).getAsString();
96+
modelSeed = root.get(IMPROVE_SEED_KEY).getAsLong();
7997

80-
JsonArray featuresArray = root.get("feature_names").getAsJsonArray();
98+
JsonArray featuresArray = root.get(IMPROVE_FEAtURES_KEY).getAsJsonArray();
8199
modelFeatureNames = new ArrayList<>(featuresArray.size());
82100
for (int i = 0; i < featuresArray.size(); ++i) {
83101
modelFeatureNames.add(featuresArray.get(i).getAsString());
84102
}
103+
104+
Type type = new TypeToken<Map<String, List<Long>>>(){}.getType();
105+
stringTables = new Gson().fromJson(root.get(IMPROVE_STRING_TABLES_KEY), type);
85106
} catch (RuntimeException e) {
86107
throw new IOException("Failed to parse the model metadata. Looks like the model being loaded is invalid.");
87108
}

0 commit comments

Comments
 (0)