import os
import sys
import tarfile
import time
import urllib.request

import numpy as np
import pandas as pd
from packaging import version
from torch.utils.data import Dataset
from tqdm import tqdm


def reporthook(count, block_size, total_size):
    # Progress callback for urllib.request.urlretrieve: prints percent done,
    # megabytes downloaded so far, average speed, and elapsed time.
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    speed = progress_size / (1024.0 ** 2 * duration)
    percent = count * block_size * 100.0 / total_size

    sys.stdout.write(
        f"\r{int(percent)}% | {progress_size / (1024.0 ** 2):.2f} MB "
        f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed"
    )
    sys.stdout.flush()


def download_dataset():
    source = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
    target = "aclImdb_v1.tar.gz"

    # Remove any stale (possibly partial) archive so it is re-downloaded fresh.
    if os.path.exists(target):
        os.remove(target)

    if not os.path.isdir("aclImdb") and not os.path.isfile("aclImdb_v1.tar.gz"):
        urllib.request.urlretrieve(source, target, reporthook)

    if not os.path.isdir("aclImdb"):
        with tarfile.open(target, "r:gz") as tar:
            tar.extractall()


def load_dataset_into_to_dataframe():
    basepath = "aclImdb"

    labels = {"pos": 1, "neg": 0}

    df = pd.DataFrame()

    with tqdm(total=50000) as pbar:
        for s in ("test", "train"):
            for l in ("pos", "neg"):
                path = os.path.join(basepath, s, l)
                for file in sorted(os.listdir(path)):
                    with open(os.path.join(path, file), "r", encoding="utf-8") as infile:
                        txt = infile.read()

                    if version.parse(pd.__version__) >= version.parse("1.3.2"):
                        x = pd.DataFrame(
                            [[txt, labels[l]]], columns=["review", "sentiment"]
                        )
                        # ignore_index=True keeps the row index unique; with
                        # duplicate index labels, the reindex shuffle below
                        # would raise a ValueError.
                        df = pd.concat([df, x], ignore_index=True)
                    else:
                        # DataFrame.append was removed in pandas 2.0; this
                        # branch only runs on older pandas versions.
                        df = df.append([[txt, labels[l]]], ignore_index=True)
                    pbar.update()
    df.columns = ["text", "label"]

    # Shuffle the rows deterministically.
    np.random.seed(0)
    df = df.reindex(np.random.permutation(df.index))

    print("Class distribution:")
    print(np.bincount(df["label"].values))

    return df


def partition_dataset(df):
    # 35k train / 5k validation / 10k test split of the 50k reviews.
    # drop=True discards the old index instead of writing it into the CSVs
    # as a spurious extra column.
    df_shuffled = df.sample(frac=1, random_state=1).reset_index(drop=True)

    df_train = df_shuffled.iloc[:35_000]
    df_val = df_shuffled.iloc[35_000:40_000]
    df_test = df_shuffled.iloc[40_000:]

    df_train.to_csv("train.csv", index=False, encoding="utf-8")
    df_val.to_csv("val.csv", index=False, encoding="utf-8")
    df_test.to_csv("test.csv", index=False, encoding="utf-8")


class IMDBDataset(Dataset):
    def __init__(self, dataset_dict, partition_key="train"):
        self.partition = dataset_dict[partition_key]

    def __getitem__(self, index):
        return self.partition[index]

    def __len__(self):
        return self.partition.num_rows
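

# A minimal end-to-end usage sketch (an addition, not part of the original
# module). It assumes the Hugging Face `datasets` package, since IMDBDataset
# indexes a dict of splits whose entries expose `num_rows`, as a
# `datasets.DatasetDict` does; the batch size and split names are illustrative.
if __name__ == "__main__":
    from datasets import load_dataset
    from torch.utils.data import DataLoader

    download_dataset()
    df = load_dataset_into_to_dataframe()
    partition_dataset(df)

    # Reload the CSV partitions written by partition_dataset() as a
    # DatasetDict with "text" and "label" columns.
    dataset_dict = load_dataset(
        "csv",
        data_files={"train": "train.csv", "validation": "val.csv", "test": "test.csv"},
    )

    train_loader = DataLoader(
        IMDBDataset(dataset_dict, partition_key="train"),
        batch_size=32,
        shuffle=True,
    )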