From a415d3ead098062d81d310cc7ae1c899c293f677 Mon Sep 17 00:00:00 2001
From: Ryan Kuester
Date: Wed, 23 Apr 2025 22:28:48 -0500
Subject: [PATCH] feat: add tool for visualizing compressed (and uncompressed)
 models

Add a development tool that prints compressed and uncompressed .tflite
models to stdout in a human-readable, searchable, structured, text
format. Helpful annotations (indexes of lists, names of operators, etc.)
derived from the model are added as virtual fields with names beginning
with an _underscore.

Add a unit test which simply ensures the viewer does not crash when run
on several models found in the source tree.

BUG=see description
---
 tensorflow/lite/micro/compression/BUILD         |  22 +
 tensorflow/lite/micro/compression/view.py       | 424 ++++++++++++++++++
 .../lite/micro/compression/view_test.py         |  21 +
 .../lite/micro/compression/view_tests.bzl       |  32 ++
 third_party/python_requirements.in              |   1 +
 third_party/python_requirements.txt             |   9 +
 6 files changed, 509 insertions(+)
 create mode 100644 tensorflow/lite/micro/compression/view.py
 create mode 100644 tensorflow/lite/micro/compression/view_test.py
 create mode 100644 tensorflow/lite/micro/compression/view_tests.bzl

diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD
index 8e037260215..ea85f1c9fdf 100644
--- a/tensorflow/lite/micro/compression/BUILD
+++ b/tensorflow/lite/micro/compression/BUILD
@@ -10,6 +10,7 @@ load(
 )
 load("@rules_python//python:defs.bzl", "py_test")
 load("@tflm_pip_deps//:requirements.bzl", "requirement")
+load(":view_tests.bzl", "generate_view_tests")
 
 package(
     default_visibility = [
@@ -190,3 +191,24 @@ py_test(
         requirement("tensorflow"),
     ],
 )
+
+py_binary(
+    name = "view",
+    srcs = [
+        "view.py",
+    ],
+    deps = [
+        ":metadata_py",
+        "//tensorflow/lite/python:schema_py",
+        "@absl_py//absl:app",
+        requirement("bitarray"),
+        requirement("prettyprinter"),
+    ],
+)
+
+generate_view_tests([
+    "//tensorflow/lite/micro/models:keyword_scrambled.tflite",
+    "//tensorflow/lite/micro/models:keyword_scrambled_8bit.tflite",
+    "//tensorflow/lite/micro/models:person_detect.tflite",
+    "//tensorflow/lite/micro/models:person_detect_vela.tflite",
+])
diff --git a/tensorflow/lite/micro/compression/view.py b/tensorflow/lite/micro/compression/view.py
new file mode 100644
index 00000000000..d0407636a2a
--- /dev/null
+++ b/tensorflow/lite/micro/compression/view.py
@@ -0,0 +1,424 @@
+# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This development tool prints compressed and uncompressed .tflite models to
+# stdout in a human-readable, searchable, structured, text format. Helpful
+# annotations (indexes of lists, names of operators, etc.) derived from the
+# model are added as virtual fields with names beginning with an _underscore.
+#
+# # Theory of operation
+#
+# Convert the model into a Python dictionary, expressing the hierarchical
+# nature of the model, and pretty print the dictionary. Please extend as
+# needed for your use case.
+ +from dataclasses import dataclass +from enum import Enum +import bitarray +import bitarray.util +import numpy as np +import os +import prettyprinter +import prettyprinter.doc +import sys +import textwrap + +import absl.app + +from tensorflow.lite.micro.compression import metadata_py_generated as compression_schema +from tensorflow.lite.python import schema_py_generated as tflite_schema + +USAGE = textwrap.dedent(f"""\ + Usage: {os.path.basename(sys.argv[0])} $(realpath ) + Print a visualization of a .tflite model.""") + + +def print_model(model_path): + with open(model_path, 'rb') as flatbuffer: + d = create_dictionary(memoryview(flatbuffer.read())) + prettyprinter.cpprint(d) + + +def main(argv): + try: + model_path = argv[1] + except IndexError: + sys.stderr.write(USAGE) + sys.exit(1) + + print_model(model_path) + + +@dataclass +class MetadataReader: + model: tflite_schema.ModelT + buffer_index: int + metadata: compression_schema.MetadataT + + @classmethod + def build(cls, model: tflite_schema.ModelT): + if model.metadata is None: + return None + + for item in model.metadata: + if _decode_name(item.name) == "COMPRESSION_METADATA": + buffer_index = item.buffer + buffer = model.buffers[buffer_index] + metadata = compression_schema.MetadataT.InitFromPackedBuf( + buffer.data, 0) + if metadata.subgraphs is None: + raise ValueError("Invalid compression metadata") + return cls(model, buffer_index, metadata) + else: + return None + + def unpack(self): + result = [] + for index, subgraph in enumerate(self.metadata.subgraphs): + result.append({ + "_index": index, + "lut_tensors": unpack_lut_metadata(subgraph.lutTensors), + }) + return {"subgraphs": result} + + +def unpack_list(source, index_name="_index"): + result = [] + for index, s in enumerate(source): + d = {index_name: index} | vars(s) + result.append(d) + return result + + +def unpack_operators(model: tflite_schema.ModelT, + operators: list[tflite_schema.OperatorT]): + result = [] + for index, op in 
enumerate(operators): + opcode = model.operatorCodes[op.opcodeIndex] + name = OPERATOR_NAMES[opcode.builtinCode] + d = { + "_operator": index, + "opcode_index": op.opcodeIndex, + "_opcode_name": name, + "inputs": op.inputs, + "outputs": op.outputs, + } + result.append(d) + return result + + +def unpack_TensorType(type): + attrs = [ + attr for attr in dir(tflite_schema.TensorType) + if not attr.startswith("__") + ] + lut = {getattr(tflite_schema.TensorType, attr): attr for attr in attrs} + return lut[type] + + +def _decode_name(name): + """Returns name as a str or 'None'. + + The flatbuffer library returns names as bytes objects or None. This function + returns a str, decoded from the bytes object, or None. + """ + if name is None: + return None + else: + return str(name, encoding="utf-8") + + +@dataclass +class TensorCoordinates: + subgraph_ix: int + tensor_index: int + + +class CompressionMethod(Enum): + LUT = "LUT" + + +_NP_DTYPES = { + tflite_schema.TensorType.FLOAT16: np.dtype(" list[CompressionMethod]: + metadata = self._tensor_metadata(coordinates) + if metadata: + return [CompressionMethod.LUT] + else: + return [] + + def lookup_tables(self, coordinates: TensorCoordinates) -> np.ndarray: + metadata = self._tensor_metadata(coordinates) + if not metadata: + return np.array([]) + + model_subgraph = self.model.subgraphs[coordinates.subgraph_ix] + model_tensor = model_subgraph.tensors[coordinates.tensor_index] + value_buffer = self.model.buffers[metadata.valueBuffer] + values = np.frombuffer(bytes(value_buffer.data), + dtype=_NP_DTYPES[model_tensor.type]) + values_per_table = 2**metadata.indexBitwidth + tables = len(values) // values_per_table + values = values.reshape((tables, values_per_table)) + + return values + + +def unpack_tensors(tensors, subgraph_index: int, codec: Codec | None): + result = [] + for index, t in enumerate(tensors): + d = { + "_tensor": index, + "name": _decode_name(t.name), + "type": unpack_TensorType(t.type), + "shape": t.shape, + 
"buffer": t.buffer, + } + + if t.isVariable: + d["is_variable"] = True + else: + # don't display this unusual field + pass + + if t.quantization is not None and t.quantization.scale is not None: + d["quantization"] = { + "scale": t.quantization.scale, + "zero": t.quantization.zeroPoint, + "dimension": t.quantization.quantizedDimension, + } + result.append(d) + + if codec is not None: + coordinates = TensorCoordinates(subgraph_ix=subgraph_index, + tensor_index=index) + d |= unpack_compression(coordinates, codec) + + return result + + +def unpack_compression(tensor: TensorCoordinates, codec: Codec) -> dict: + result = {} + + compressions = codec.list_compressions(tensor) + if compressions: + result["_compressed"] = [c.name for c in compressions] + metadata = codec._tensor_metadata(tensor) + assert metadata is not None + result["_value_buffer"] = metadata.valueBuffer + result["_lookup_tables"] = codec.lookup_tables(tensor) + + return result + + +def unpack_subgraphs(model: tflite_schema.ModelT, codec: Codec | None): + result = [] + for index, s in enumerate(model.subgraphs): + d = { + "_subgraph": index, + "_operator_count": len(s.operators), + "_tensor_count": len(s.tensors), + "name": _decode_name(s.name), + "operators": unpack_operators(model, s.operators), + "tensors": unpack_tensors(s.tensors, subgraph_index=index, + codec=codec), + } + result.append(d) + return result + + +def unpack_opcodes(opcodes: list[tflite_schema.OperatorCodeT]) -> list: + result = [] + for index, opcode in enumerate(opcodes): + d: dict = { + "_opcode_index": index, + "_name": OPERATOR_NAMES[opcode.builtinCode], + "builtin_code": opcode.builtinCode, + "version": opcode.version, + } + if opcode.customCode is not None: + d["custom_code"] = opcode.customCode + del d["_name"] + result.append(d) + return result + + +def unpack_metadata(model: tflite_schema.ModelT): + entries = [] + compression = MetadataReader.build(model) + + if model.metadata is None: + return entries + + for m in 
model.metadata: + d = {"name": _decode_name(m.name), "buffer": m.buffer} + + if compression and compression.buffer_index == m.buffer: + d["_compression_metadata"] = compression.unpack() + + entries.append(d) + + return entries + + +def unpack_lut_metadata(lut_tensors): + return [{ + "tensor": t.tensor, + "value_buffer": t.valueBuffer, + "index_bitwidth": t.indexBitwidth, + } for t in sorted(lut_tensors, key=lambda x: x.tensor)] + + +def is_compressed_buffer(buffer_index, unpacked_metadata): + if unpacked_metadata is None: + return False, None, None + for subgraph in unpacked_metadata["subgraphs"]: + lut_list = subgraph["lut_tensors"] + subgraph_index = subgraph["_index"] + item = next( + (item for item in lut_list if item["value_buffer"] == buffer_index), + None) + if item is not None: + return True, item, subgraph_index + return False, None, None + + +def unpack_indices(buffer, lut_data): + bstring = bitarray.bitarray() + bstring.frombytes(bytes(buffer.data)) + bitwidth = lut_data["index_bitwidth"] + indices = [] + while len(bstring) > 0: + indices.append(bitarray.util.ba2int(bstring[0:bitwidth])) + del bstring[0:bitwidth] + return indices + + +def unpack_compression_metadata(buffer): + buffer = bytes(buffer.data) + metadata = compression_schema.MetadataT.InitFromPackedBuf(buffer, 0) + if metadata.subgraphs is None: + raise ValueError("Invalid compression metadata") + result = [] + for index, s in enumerate(metadata.subgraphs): + d = {"_index": index, "lut_tensors": unpack_lut_metadata(s.lutTensors)} + result.append(d) + return {"subgraphs": result} + + +def unpack_buffers(model, compression_data): + buffers = [] + for index, buffer in enumerate(model.buffers): + native = { + "_buffer": index, + "_bytes": len(buffer.data) if buffer.data is not None else 0, + } + + if compression_data is not None and index == compression_data.buffer_index: + native["_compression_metadata"] = True + + native["data"] = buffer.data + + buffers.append(native) + + return buffers + + 
+def get_compression_metadata_buffer(model): + """Returns the metadata buffer data or None.""" + if model.metadata is None: + return None + for item in model.metadata: + if _decode_name(item.name) == "COMPRESSION_METADATA": + return item.buffer + return None + + +def create_dictionary(flatbuffer: memoryview) -> dict: + """Returns a human-readable dictionary from the provided model flatbuffer. + + This function transforms a .tflite model flatbuffer into a Python dictionary. + When pretty-printed, this dictionary offers an easily interpretable view of + the model. + """ + model = tflite_schema.ModelT.InitFromPackedBuf(flatbuffer, 0) + compression_metadata = MetadataReader.build(model) + codec = Codec(compression_metadata, model) if compression_metadata else None + + output = { + "description": model.description, + "version": model.version, + "operator_codes": unpack_opcodes(model.operatorCodes), + # "operator_codes": unpack_list(model.operatorCodes), + "metadata": unpack_metadata(model), + "subgraphs": unpack_subgraphs(model, codec), + "buffers": unpack_buffers(model, compression_metadata), + } + + return output + + +@prettyprinter.register_pretty(np.ndarray) +def pretty_numpy_array(array, ctx): + string = np.array2string(array) + lines = string.splitlines() + + if len(lines) == 1: + return lines[0] + + parts = list() + parts.append(prettyprinter.doc.HARDLINE) + for line in lines: + parts.append(line) + parts.append(prettyprinter.doc.HARDLINE) + + return prettyprinter.doc.nest(ctx.indent, prettyprinter.doc.concat(parts)) + + +if __name__ == "__main__": + sys.modules['__main__'].__doc__ = USAGE + absl.app.run(main) diff --git a/tensorflow/lite/micro/compression/view_test.py b/tensorflow/lite/micro/compression/view_test.py new file mode 100644 index 00000000000..5befce16dd7 --- /dev/null +++ b/tensorflow/lite/micro/compression/view_test.py @@ -0,0 +1,21 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import view
+
+# The test simply makes sure the viewer runs without returning an error.
+
+model_path = sys.argv[1]
+view.print_model(model_path)
diff --git a/tensorflow/lite/micro/compression/view_tests.bzl b/tensorflow/lite/micro/compression/view_tests.bzl
new file mode 100644
index 00000000000..f43e5d83863
--- /dev/null
+++ b/tensorflow/lite/micro/compression/view_tests.bzl
@@ -0,0 +1,32 @@
+def generate_view_tests(targets):
+    """Generates py_test targets for each target's path and a test_suite to
+    group them.
+
+    Args:
+      targets: List of target labels to .tflite models with which to test.
+    """
+    test_names = []
+    for target in targets:
+        # Create a test name from the last component of the target name
+        short_name = target.split(":")[-1] if ":" in target else target.split("/")[-1]
+        test_name = "view_test_{}".format(short_name.replace(".", "_"))
+
+        native.py_test(
+            name = test_name,
+            srcs = ["view_test.py"],
+            args = ["$(location {})".format(target)],
+            main = "view_test.py",
+            data = [target],
+            deps = [
+                ":view",
+                "@absl_py//absl/testing:absltest",
+            ],
+            size = "small",
+        )
+        test_names.append(test_name)
+
+    # Create a test suite for all generated tests
+    native.test_suite(
+        name = "view_tests",
+        tests = test_names,
+    )
diff --git a/third_party/python_requirements.in b/third_party/python_requirements.in
index 581cb423d27..ace27598d31 100644
--- a/third_party/python_requirements.in
+++ b/third_party/python_requirements.in
@@ -33,5 +33,6 @@ numpy
 mako
 pillow
 yapf
+prettyprinter
 protobuf
 pyyaml
diff --git a/third_party/python_requirements.txt b/third_party/python_requirements.txt
index 9ac917bcb20..92fde1ddb5e 100644
--- a/third_party/python_requirements.txt
+++ b/third_party/python_requirements.txt
@@ -338,6 +338,10 @@ charset-normalizer==3.4.0 \
     --hash=sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079 \
     --hash=sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482
     # via requests
+colorful==0.5.6 \
+    --hash=sha256:b56d5c01db1dac4898308ea889edcb113fbee3e6ec5df4bacffd61d5241b5b8d \
+    --hash=sha256:eab8c1c809f5025ad2b5238a50bd691e26850da8cac8f90d660ede6ea1af9f1e
+    # via prettyprinter
 cryptography==43.0.3 \
     --hash=sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362 \
     --hash=sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4 \
@@ -891,6 +895,10 @@ platformdirs==4.3.6 \
     --hash=sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907 \
     --hash=sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb
     # via yapf
+prettyprinter==0.18.0 \
+    --hash=sha256:358a58f276cb312e3ca29d7a7f244c91e4e0bda7848249d30e4f36d2eb58b67c \
+    --hash=sha256:9fe5da7ec53510881dd35d7a5c677ba45f34cfe6a8e78d1abd20652cf82139a8
+    # via -r third_party/python_requirements.in
 protobuf==5.28.3 \
     --hash=sha256:0c4eec6f987338617072592b97943fdbe30d019c56126493111cf24344c1cc24 \
     --hash=sha256:135658402f71bbd49500322c0f736145731b16fc79dc8f367ab544a17eab4535
@@ -915,6 +923,7 @@ pygments==2.18.0 \
     --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \
     --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a
     # via
+    #   prettyprinter
     #   readme-renderer
     #   rich
 pyyaml==6.0.2 \