Skip to content

Commit 62f30eb

Browse files
committed
publish dataset & add README
1 parent a484334 commit 62f30eb

File tree

11 files changed

+91
-22
lines changed

11 files changed

+91
-22
lines changed

README.md

+47-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,47 @@
1-
# SSP-MMC
1+
# SSP-MMC
2+
3+
Copyright (c) 2022 [MaiMemo](https://www.maimemo.com/), Inc. MIT License.
4+
5+
Stochastic-Shortest-Path-Minimize-Memorization-Cost (SSP-MMC) is a spaced repetition scheduling algorithm used to help learners remember more words in MaiMemo, a language learning application in China.
6+
7+
This repository contains a public release of the data and code used for several experiments in the following paper (which introduces SSP-MMC):
8+
9+
> Awaiting the review decision from SIGKDD 2022.
10+
11+
# Software
12+
13+
The file `data_preprocessing.py` is used to preprocess data for the DHP model.
14+
15+
The file `cal_model_param.py` contains the DHP model and HLR model.
16+
17+
The file `model/utils.py` saves the parameters of the DHP model for training and simulation.
18+
19+
The file `algo/main.cpp` contains a C++ implementation of SSP-MMC, which aims to find the optimal policy.
20+
21+
The file `simulator.py` provides an environment for comparing different scheduling algorithms.
22+
23+
## Workflow
24+
25+
1. Run `data_preprocessing.py` -> `halflife_for_fit.tsv`
26+
2. Run `cal_model_param.py` -> `intercept_` and `coef_` for the DHP model
27+
3. Save the parameters to the functions `cal_recall_halflife` and `cal_forget_halflife` in `model/utils.py` and the function `cal_next_recall_halflife` in `algo/main.cpp`
28+
4. Run `algo/main.cpp` -> optimal policy in `algo/result/`
29+
5. Run `simulator.py` to compare SSP-MMC with several baselines.
30+
31+
## Data Set and Format
32+
33+
The dataset is available on [Dataverse](https://doi.org/10.7910/DVN/VAGUL0) (1.6 GB). It is a 7-Zip-compressed TSV file containing the 220 million MaiMemo student memory behavior logs used in our experiments.
34+
35+
The columns are as follows:
36+
37+
- `u` - student user ID who reviewed the word (anonymized)
38+
- `w` - spelling of the word
39+
40+
- `i` - total times the user has reviewed the word
41+
- `d` - difficulty of the word
42+
- `t_history` - interval sequence of the historic reviews
43+
- `r_history` - recall sequence of the historic reviews
44+
- `delta_t` - time elapsed from the last review
45+
- `r` - result of the review
46+
- `p_recall` - probability of recall
47+
- `total_cnt` - number of users who did the same memory behavior

algo/main.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ float cal_start_halflife(int difficulty) {
2222

2323
float cal_next_recall_halflife(float h, float p, int d, int recall) {
2424
if (recall == 1) {
25-
return h * (1 + exp(3.80863264) * pow(d, -0.53420593) * pow(h, -0.127362) * pow(1 - p, 0.967804));
25+
return h * (1 + exp(3.81) * pow(d, -0.534) * pow(h, -0.127) * pow(1 - p, 0.97));
2626
} else {
27-
return exp(-0.04158382) * pow(d, -0.04067209) * pow(h, 0.37745957) * pow(1 - p, -0.22724425);
27+
return exp(-0.041) * pow(d, -0.041) * pow(h, 0.377) * pow(1 - p, -0.227);
2828
}
2929
}
3030

algo/result/README.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Introduction
2+
3+
The CSV files whose names begin with `cost` record the expected review cost for each memory state.
4+
5+
The CSV files whose names begin with `ivl` record the optimal review interval for each memory state.
6+
7+
The CSV files whose names begin with `recall` record the recall probability corresponding to the optimal review interval for each memory state.

cal_model_param.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ def fit_recall_halflife(raw):
7575
raw['predict_halflife_hlr'] = y_pred
7676
fig = go.Figure()
7777
fig.add_trace(
78-
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_dhp'], marker_size=np.log(raw['group_cnt']) / 2,
78+
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_dhp'], marker_size=np.log(raw['group_cnt']),
7979
mode='markers',
8080
name='DHP'))
8181
fig.add_trace(
82-
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_hlr'], marker_size=np.log(raw['group_cnt']) / 2,
82+
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_hlr'], marker_size=np.log(raw['group_cnt']),
8383
mode='markers',
8484
name='HLR', opacity=0.7))
8585
fig.update_xaxes(title_text='observed half-life after recall')
@@ -147,11 +147,11 @@ def fit_forget_halflife(raw):
147147
raw['predict_halflife_hlr'] = y_pred
148148
fig = go.Figure()
149149
fig.add_trace(
150-
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_dhp'], marker_size=np.log(raw['group_cnt']) / 2,
150+
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_dhp'], marker_size=np.log(raw['group_cnt']),
151151
mode='markers',
152152
name='DHP'))
153153
fig.add_trace(
154-
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_hlr'], marker_size=np.log(raw['group_cnt']) / 2,
154+
go.Scatter(x=raw['halflife'], y=raw['predict_halflife_hlr'], marker_size=np.log(raw['group_cnt']),
155155
mode='markers',
156156
name='HLR', opacity=0.7))
157157
fig.update_xaxes(title_text='observed half-life after forget')

data/README.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Introduction
2+
3+
Please unzip the dataset files to this directory. The necessary files include:
4+
5+
- `opensource_dataset_difficulty.tsv`
6+
- `opensource_dataset_forgetting_curve.tsv`

fit_data.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import sys
44
import os
55
import time
6-
import numpy as np
76
import pandas as pd
87
import math
98
from collections import namedtuple

model/utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,16 @@ def cal_start_halflife(difficulty):
2323

2424
def cal_recall_halflife(difficulty, halflife, p_recall):
2525
return halflife * (
26-
1 + np.exp(3.80863264) * np.power(difficulty, -0.53420593) * np.power(halflife, -0.127362) * np.power(
27-
1 - p_recall, 0.9678043))
26+
1 + np.exp(3.81) * np.power(difficulty, -0.534) * np.power(halflife, -0.127) * np.power(
27+
1 - p_recall, 0.97))
2828

2929

3030
def cal_forget_halflife(difficulty, halflife, p_recall):
31-
return np.exp(-0.04158382) * np.power(difficulty, -0.04067209) * np.power(halflife, 0.37745957) * np.power(
32-
1 - p_recall, -0.22724425)
31+
return np.exp(-0.041) * np.power(difficulty, -0.041) * np.power(halflife, 0.377) * np.power(
32+
1 - p_recall, -0.227)
3333

3434

35+
# the following code is from https://github.com/Networks-Learning/memorize
3536
def intensity(t, n_t, q):
3637
return 1.0 / np.sqrt(q) * (1 - np.exp(-n_t * t))
3738

plot/README.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Introduction
2+
3+
The PDF files show the statistical information about the dataset.
4+
5+
The HTML files are the dynamic version.

simulation_result/README.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Introduction
2+
3+
The TSV files record the review history of each word during the simulation. Their names represent the scheduling algorithm they used.
4+
5+
The PDF files show the processes of simulation.

simulator.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -200,23 +200,23 @@ def scheduler(difficulty, halflife, reps, lapses, method):
200200
total_cost = int(sum(cost_per_day))
201201

202202
plt.figure(1)
203-
plt.plot(record_per_day, label=f'{method}', linewidth=0.8)
203+
plt.plot(record_per_day, label=f'{method}')
204204

205205
plt.figure(2)
206-
plt.plot(meet_target_per_day, label=f'{method}', linewidth=0.8)
206+
plt.plot(meet_target_per_day, label=f'{method}')
207207
cost_day = np.argmax(meet_target_per_day >= compare_target)
208208
if cost_day > 0:
209209
print(f'cost day: {cost_day}')
210210
plt.plot(cost_day, compare_target, 'k*', linewidth=2)
211211

212212
plt.figure(3)
213-
plt.plot(new_item_per_day_average_per_period, label=f'{method}', linewidth=0.8)
213+
plt.plot(new_item_per_day_average_per_period, label=f'{method}')
214214

215215
plt.figure(4)
216-
plt.plot(cost_per_day_average_per_period, label=f'{method}', linewidth=0.8)
216+
plt.plot(cost_per_day_average_per_period, label=f'{method}')
217217

218218
plt.figure(5)
219-
plt.plot(learned_per_day, label=f'{method}', linewidth=0.8)
219+
plt.plot(learned_per_day, label=f'{method}')
220220

221221
print('acc learn', total_learned)
222222
print('meet target', meet_target)
@@ -236,7 +236,7 @@ def scheduler(difficulty, halflife, reps, lapses, method):
236236
# plt.plot(....)
237237
pdf.savefig()
238238
plt.figure(2)
239-
plt.plot((0, learn_days), (compare_target, compare_target), color='black', linestyle='dotted', linewidth=0.8)
239+
plt.plot((0, learn_days), (compare_target, compare_target), color='black', linestyle='dotted')
240240
plt.title(f"day cost limit:{day_cost_limit}-learn days:{learn_days}")
241241
plt.xlabel("days")
242242
plt.ylabel("THR")

visualization.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ def policy_action_visualize():
234234

235235

236236
if __name__ == "__main__":
237-
# difficulty_visualize()
238-
# forgetting_curve_visualize()
239-
# raw_data_visualize()
240-
# dhp_model_visualize()
237+
difficulty_visualize()
238+
forgetting_curve_visualize()
239+
raw_data_visualize()
240+
dhp_model_visualize()
241241
policy_action_visualize()

0 commit comments

Comments
 (0)