#!/usr/bin/env python3
"""Converting between floats and binary representations.

This code converts tensors of floats into tensors of bits and vice versa.
We follow the IEEE-754 standard [1] to convert. The conversion defaults are
based on 32-bit / single-precision floats: 8 exponent bits and 23 mantissa
bits. Other common formats are:

  total bits   precision   exponent bits   mantissa bits   bias
  --------------------------------------------------------------
  64           double      11              52              1023
  32           single      8               23              127
  16           half        5               10              15

Available functions:
  * bit2float
  * float2bit
  * integer2bit
  * remainder2bit

[1] IEEE Computer Society (2008-08-29). IEEE Standard for Floating-Point
    Arithmetic. IEEE Std 754-2008. IEEE. pp. 1–70.
    doi:10.1109/IEEESTD.2008.4610935. ISBN 978-0-7381-5753-5.

Author: Karen Ullrich, June 2019
"""
import torch
import warnings


def bit2float(b, num_e_bits=8, num_m_bits=23, bias=127.):
  """Turn input tensor into float.

  Args:
      b : binary tensor. The last dimension of this tensor should hold
          the bits.
      num_e_bits : Number of exponent bits. Default: 8.
      num_m_bits : Number of mantissa bits. Default: 23.
      bias : Exponent bias / zero offset. Default: 127.

  Returns:
      Tensor: Float tensor. Reduces the last dimension.
  """
  expected_last_dim = num_m_bits + num_e_bits + 1
  assert b.shape[-1] == expected_last_dim, \
      "Binary tensor's last dimension should be {}, not {}.".format(
          expected_last_dim, b.shape[-1])
  # Check whether the requested format fits into a PyTorch float type.
  dtype = torch.float32
  if expected_last_dim > 32:
    dtype = torch.float64
  if expected_last_dim > 64:
    warnings.warn("PyTorch cannot process floats larger than 64 bits; keep"
                  " this in mind. Your result will not be exact.")

  # Split the bit tensor into sign, exponent, and mantissa fields.
  s = torch.index_select(b, -1, torch.arange(0, 1))
  e = torch.index_select(b, -1, torch.arange(1, 1 + num_e_bits))
  m = torch.index_select(b, -1, torch.arange(1 + num_e_bits,
                                             1 + num_e_bits + num_m_bits))
  # SIGN BIT
  out = ((-1) ** s).squeeze(-1).type(dtype)
  # EXPONENT BITS: weights 2^(num_e_bits - 1) ... 2^0, MSB first.
  exponents = -torch.arange(-(num_e_bits - 1.), 1.)
  exponents = exponents.repeat(b.shape[:-1] + (1,))
  e_decimal = torch.sum(e * 2 ** exponents, dim=-1) - bias
  out *= 2 ** e_decimal
  # MANTISSA BITS: weights 2^-1 ... 2^-num_m_bits, plus the implicit
  # leading one of a normalized float.
  mantissa = (torch.Tensor([2.]) ** (
      -torch.arange(1., num_m_bits + 1.))).repeat(
      m.shape[:-1] + (1,))
  out *= 1. + torch.sum(m * mantissa, dim=-1)
  return out


def float2bit(f, num_e_bits=8, num_m_bits=23, bias=127., dtype=torch.float32):
  """Turn input tensor into binary.

  Args:
      f : float tensor. Note that zeros, infs, and NaNs are not handled,
          since the exponent is derived via log2.
      num_e_bits : Number of exponent bits. Default: 8.
      num_m_bits : Number of mantissa bits. Default: 23.
      bias : Exponent bias / zero offset. Default: 127.
      dtype : Data type of the returned tensor. Default: torch.float32.

  Returns:
      Tensor: Binary tensor. Adds a last dimension to the original tensor
      for the bits.
  """
  ## SIGN BIT
  s = torch.sign(f)
  f = f * s
  # turn sign into sign-bit: +1 -> 0, -1 -> 1
  s = (s * (-1) + 1.) * 0.5
  s = s.unsqueeze(-1)

  ## EXPONENT BITS
  e_scientific = torch.floor(torch.log2(f))
  e_decimal = e_scientific + bias
  e = integer2bit(e_decimal, num_bits=num_e_bits)

  ## MANTISSA BITS
  # Bits of the integer part, followed by `bias` bits of the fractional
  # part. Very small magnitudes (e_scientific close to -bias) can push the
  # gather index below past the available bits.
  m1 = integer2bit(f - f % 1, num_bits=num_e_bits)
  m2 = remainder2bit(f % 1, num_bits=bias)
  m = torch.cat([m1, m2], dim=-1)
  # The leading one sits at position num_e_bits - 1 - e_scientific, so the
  # mantissa starts one position after it.
  idx = torch.arange(num_m_bits).unsqueeze(0).type(f.type()) \
        + (float(num_e_bits) - e_scientific).unsqueeze(-1)
  idx = idx.long()
  m = torch.gather(m, dim=-1, index=idx)
  return torch.cat([s, e, m], dim=-1).type(dtype)
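

# Illustrative round-trip sketch (default 32-bit layout assumed; both values
# are exactly representable, so the round trip is exact):
#
#   >>> x = torch.tensor([0.5, -2.25])
#   >>> bit2float(float2bit(x))
#   tensor([ 0.5000, -2.2500])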


def remainder2bit(remainder, num_bits=127):
  """Turn a tensor of remainders (floats < 1) into mantissa bits.

  Args:
      remainder : torch.Tensor, tensor with remainders
      num_bits : Number of bits to specify the precision. Default: 127.

  Returns:
      Tensor: Binary tensor. Adds a last dimension to the original tensor
      for the bits.
  """
  dtype = remainder.type()
  exponent_bits = torch.arange(num_bits).type(dtype)
  exponent_bits = exponent_bits.repeat(remainder.shape + (1,))
  # Shift bit k in front of the binary point, drop everything above it,
  # and read it off with a floor.
  out = (remainder.unsqueeze(-1) * 2 ** exponent_bits) % 1
  return torch.floor(2 * out)
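

# For instance, 0.5 is 0.1 in binary (an illustrative call with 4 bits):
#
#   >>> remainder2bit(torch.tensor([0.5]), num_bits=4)
#   tensor([[1., 0., 0., 0.]])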


def integer2bit(integer, num_bits=8):
  """Turn integer tensor into its binary representation.

  Args:
      integer : torch.Tensor, tensor with integers
      num_bits : Number of bits to specify the precision. Default: 8.

  Returns:
      Tensor: Binary tensor. Adds a last dimension to the original tensor
      for the bits.
  """
  dtype = integer.type()
  exponent_bits = -torch.arange(-(num_bits - 1), 1).type(dtype)
  exponent_bits = exponent_bits.repeat(integer.shape + (1,))
  # Shift bit k down to the ones place (MSB first), floor away the
  # fraction, and keep the parity.
  out = integer.unsqueeze(-1) / 2 ** exponent_bits
  return (out - (out % 1)) % 2
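

# Minimal smoke test: a sketch that round-trips a few values through
# float2bit/bit2float under the default 32-bit layout (zero is excluded
# because float2bit derives the exponent via log2).
if __name__ == "__main__":
  print(integer2bit(torch.tensor([5.]), num_bits=4))  # tensor([[0., 1., 0., 1.]])
  x = torch.tensor([0.5, -2.25, 3.14, 1e-3])
  bits = float2bit(x)
  x_hat = bit2float(bits)
  print("input :", x)
  print("output:", x_hat)
  assert torch.allclose(x, x_hat), "round trip failed"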