05-model-evaluation-and-improvement.ipynb heatmap error #184

Open
@stevehua

Running cell 34 (where mglearn.tools.heatmap() is called):

scores = np.array(results.mean_test_score).reshape(6, 6)
 
# plot the mean cross-validation scores
mglearn.tools.heatmap(scores, xlabel='gamma', xticklabels=param_grid['gamma'],
                       ylabel='C', yticklabels=param_grid['C'], cmap="viridis")
 

Produces the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[34], line 4
      1 scores = np.array(results.mean_test_score).reshape(6, 6)
      3 # plot the mean cross-validation scores
----> 4 mglearn.tools.heatmap(scores, xlabel='gamma', xticklabels=param_grid['gamma'],
      5                       ylabel='C', yticklabels=param_grid['C'], cmap="viridis")

File C:\my_venv\Lib\site-packages\mglearn\tools.py:80, in heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap, vmin, vmax, ax, fmt)
     78     else:
     79         c = 'w'
---> 80     ax.text(x, y, fmt % value, color=c, ha="center", va="center")
     81 return img

File C:\my_venv\Lib\site-packages\numpy\ma\core.py:4558, in MaskedArray.__float__(self)
   4553 """
   4554 Convert to float.
   4555 
   4556 """
   4557 if self.size > 1:
-> 4558     raise TypeError("Only length-1 arrays can be converted "
   4559                     "to Python scalars")
   4560 elif self._mask:
   4561     warnings.warn("Warning: converting a masked element to nan.", stacklevel=2)

TypeError: Only length-1 arrays can be converted to Python scalars
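The traceback fails inside `fmt % value`, where `value` is a `MaskedArray` with more than one element. This suggests the loop in `heatmap` is receiving whole rows of the score grid rather than single scores. The NumPy-only sketch below reproduces the same `TypeError`, using a hypothetical 6x6 masked array in place of the real `scores`. `grid` and `label_cells` are illustrative names, not part of mglearn.

```python
import numpy as np

# Hypothetical stand-in for the 6x6 grid of mean CV scores (not real results),
# wrapped in a masked array because the traceback shows one inside heatmap().
grid = np.ma.masked_array(np.linspace(0.0, 1.0, 36).reshape(6, 6))


def label_cells(cells, fmt="%0.2f"):
    """Format every item produced by iterating `cells`, as heatmap()'s loop does."""
    return [fmt % value for value in cells]


# Iterating a 2D array yields whole rows; `fmt % row` then calls
# MaskedArray.__float__ on a size-6 array, which raises TypeError.
try:
    label_cells(grid)
except TypeError as exc:
    print("2D input:", exc)

# Flattening first yields one scalar per cell, which formats cleanly.
labels = label_cells(np.ravel(grid))
print(len(labels), labels[0], labels[-1])  # 36 0.00 1.00
```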

Output of sklearn.show_versions():

System:
    python: 3.13.2 (tags/v3.13.2:4f8bb39, Feb  4 2025, 15:23:48) [MSC v.1942 64 bit (AMD64)]
executable: C:\my_venv\Scripts\python.exe
   machine: Windows-10-10.0.19045-SP0

Python dependencies:
      sklearn: 1.6.1
          pip: 25.0.1
   setuptools: 75.8.0
        numpy: 2.2.3
        scipy: 1.15.1
       Cython: None
       pandas: 2.2.3
   matplotlib: 3.10.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: C:\my_venv\Lib\site-packages\numpy.libs\libscipy_openblas64_-43e11ff0749b8cbe0a615c9cf6737e0e.dll
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: vcomp
       filepath: C:\my_venv\Lib\site-packages\sklearn\.libs\vcomp140.dll
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: C:\my_venv\Lib\site-packages\scipy.libs\libscipy_openblas-f07f5a5d207a3a47104dca54d6d0c86a.dll
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell
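The environment above runs matplotlib 3.10.0. A likely explanation, not confirmed in this report, is that since matplotlib 3.8 `ax.pcolor` returns a `PolyQuadMesh` whose `get_array()` keeps the 2D shape of `values`, while `heatmap` still iterates it as if it were flat. Under that assumption, flattening the array before the loop is enough. The function below is a sketch reconstructed from the signature and lines 78-81 shown in the traceback; the rest of its body is an assumption about mglearn's implementation, not a verbatim copy or an upstream patch, and it assumes the score grid has no masked or NaN cells.

```python
import numpy as np
import matplotlib.pyplot as plt


def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
            vmin=None, vmax=None, ax=None, fmt="%0.2f"):
    """Annotated heatmap; a sketch of mglearn.tools.heatmap with one fix."""
    if ax is None:
        ax = plt.gca()
    img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
    img.update_scalarmappable()  # compute face colours so they can be read back
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # one tick at the centre of each cell
    ax.set_xticks(np.arange(len(xticklabels)) + .5)
    ax.set_yticks(np.arange(len(yticklabels)) + .5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_aspect(1)

    # The fix: np.ravel gives one value per cell whether get_array()
    # returns a 2D grid (newer matplotlib) or an already-flat 1D array.
    cell_values = np.ravel(img.get_array())
    for p, color, value in zip(img.get_paths(), img.get_facecolor(),
                               cell_values):
        # cell centre, averaging the path's corner vertices as mglearn does
        x, y = p.vertices[:-2, :].mean(0)
        # dark text on light cells, white text on dark cells
        c = 'k' if np.mean(color[:3]) > 0.5 else 'w'
        ax.text(x, y, fmt % value, color=c, ha="center", va="center")
    return img
```

To try it from the notebook without editing site-packages, one option is to define the function in a cell and run `mglearn.tools.heatmap = heatmap` before cell 34. This only affects code that looks the function up through `mglearn.tools` at call time.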
