Numpy expand dims – Python NumPy expand_dims() Function

NumPy expand_dims() Function:

The expand_dims() function of the NumPy module expands the shape of an array. It inserts a new axis of length 1, which appears at the position given by axis in the expanded array's shape.
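A minimal sketch of what "inserting a new axis" means in practice: the new axis has length 1, so the data itself is unchanged and only the shape grows. The comparison with indexing by np.newaxis is added here for illustration and is not part of the original walkthrough.

```python
import numpy as np

a = np.array([5, 8, 9])          # shape (3,)

# expand_dims inserts a length-1 axis; the element count stays the same
b = np.expand_dims(a, axis=0)
print(a.shape, b.shape)          # (3,) (1, 3)
print(a.size == b.size)          # True

# Indexing with np.newaxis (an alias for None) gives the same result
c = a[np.newaxis, :]
print(np.array_equal(b, c))      # True
```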

Syntax: 

numpy.expand_dims(a, axis)

Parameters

a:  This is required. It is an input array.

axis: This is required. It is the position (or positions) in the expanded array's axes where the new axis (or axes) is inserted. It can be an int or a tuple of ints.
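Because axis refers to positions in the output array, negative values and tuples are interpreted against the expanded shape. The sketch below assumes a NumPy version recent enough to accept a tuple for axis (1.18 or later); out-of-range values raise NumPy's AxisError, which is a subclass of ValueError, so it is caught that way here.

```python
import numpy as np

a = np.array([5, 8, 9])  # shape (3,)

# axis counts positions in the *output* array, so for a 1-D input
# the valid single values are -2, -1, 0 and 1
print(np.expand_dims(a, axis=-1).shape)      # (3, 1), same as axis=1
print(np.expand_dims(a, axis=(0, 2)).shape)  # (1, 3, 1)

# An axis outside the output's range is rejected; NumPy raises AxisError,
# which subclasses ValueError
try:
    np.expand_dims(a, axis=2)
except ValueError as err:
    print("error:", err)
```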

Return Value:

It returns an ndarray: a view of a with the number of dimensions increased. The output has one more dimension than the input when axis is an int, and len(axis) more dimensions when axis is a tuple.
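The two points about the return value can be checked directly: the dimension count grows by the number of inserted axes, and the result shares memory with the input. This is a small sketch; writing through the result is shown only to make the view behaviour visible.

```python
import numpy as np

a = np.array([5, 8, 9])
b = np.expand_dims(a, axis=(0, 1))

# ndim grows by the number of inserted axes, not always by one
print(a.ndim, b.ndim)            # 1 3

# The result is a view: it shares memory with the input array
print(np.shares_memory(a, b))    # True
b[0, 0, 1] = 100                 # writing through the view...
print(a)                         # [  5 100   9]  ...changes the original
```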

NumPy expand_dims() Function in Python

Example 1: axis=0

Approach:

  • Import numpy module using the import keyword.
  • Pass the list as an argument to the array() function to create an array.
  • Store it in a variable.
  • Pass the above array and axis=0 as arguments to the expand_dims() function to insert a new axis at position 0.
  • Store it in another variable.
  • Print the given array.
  • Print the shape of the given array using the shape attribute.
  • Print the result, i.e. the expanded array.
  • Print the shape of the expanded array using the shape attribute.
  • Exit the program.

Below is the implementation:

# Import numpy module using the import keyword
import numpy as np
# Pass the list as an argument to the array() function to create an array.
# Store it in a variable.
gvn_arry = np.array([5, 8, 9])
# Pass the above array and axis=0 as arguments to expand_dims() to insert
# a new axis at position 0.
# Store it in another variable.
expnd_arry = np.expand_dims(gvn_arry, axis=0)
# Print the given array
print("The given array:\n", gvn_arry)
# Print the shape of the given array using the shape attribute
print("The shape of the given array:", gvn_arry.shape)
# Print the result, i.e. the expanded array
print("The expanded array:\n", expnd_arry)
# Print the shape of the expanded array using the shape attribute
print("The shape of the expanded array:", expnd_arry.shape)

Output:

The given array:
[5 8 9]
The shape of the given array: (3,)
The expanded array:
[[5 8 9]]
The shape of the expanded array: (1, 3)
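One hypothetical situation where axis=0 helps: np.concatenate requires all inputs to have the same number of dimensions, so a 1-D row must become a (1, 3) array before it can be appended to a (2, 3) matrix. The matrix values below are made up for the example.

```python
import numpy as np

matrix = np.array([[1, 2, 3],
                   [4, 5, 6]])       # shape (2, 3), illustrative data
row = np.array([5, 8, 9])            # shape (3,)

# np.concatenate needs inputs with the same number of dimensions,
# so turn the 1-D row into a (1, 3) array first
row_2d = np.expand_dims(row, axis=0)
stacked = np.concatenate([matrix, row_2d], axis=0)
print(stacked.shape)                  # (3, 3)
print(stacked)
```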

Example 2: axis=1

Approach:

  • Import numpy module using the import keyword.
  • Pass the list as an argument to the array() function to create an array.
  • Store it in a variable.
  • Pass the above array and axis=1 as arguments to the expand_dims() function to insert a new axis at position 1.
  • Store it in another variable.
  • Print the given array.
  • Print the shape of the given array using the shape attribute.
  • Print the result, i.e. the expanded array.
  • Print the shape of the expanded array using the shape attribute.
  • Exit the program.

Below is the implementation:

# Import numpy module using the import keyword
import numpy as np
# Pass the list as an argument to the array() function to create an array.
# Store it in a variable.
gvn_arry = np.array([5, 8, 9])
# Pass the above array and axis=1 as arguments to expand_dims() to insert
# a new axis at position 1.
# Store it in another variable.
expnd_arry = np.expand_dims(gvn_arry, axis=1)
# Print the given array
print("The given array:\n", gvn_arry)
# Print the shape of the given array using the shape attribute
print("The shape of the given array:", gvn_arry.shape)
# Print the result, i.e. the expanded array
print("The expanded array:\n", expnd_arry)
# Print the shape of the expanded array using the shape attribute
print("The shape of the expanded array:", expnd_arry.shape)

Output:

The given array:
[5 8 9]
The shape of the given array: (3,)
The expanded array:
[[5]
 [8]
 [9]]
The shape of the expanded array: (3, 1)
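The (3, 1) column produced by axis=1 is useful for broadcasting. As a sketch (the outer-product use case is an illustration, not from the original), multiplying the column by the original (3,) array broadcasts to a (3, 3) table, which matches np.outer.

```python
import numpy as np

a = np.array([5, 8, 9])

# (3, 1) column broadcast against the (3,) row gives a (3, 3) table
col = np.expand_dims(a, axis=1)
table = col * a
print(table)
# [[25 40 45]
#  [40 64 72]
#  [45 72 81]]
print(np.array_equal(table, np.outer(a, a)))  # True
```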

Example 3: axis=(0, 1)

Approach:

  • Import numpy module using the import keyword.
  • Pass the list as an argument to the array() function to create an array.
  • Store it in a variable.
  • Pass the above array and axis=(0, 1) as arguments to the expand_dims() function to insert new axes at positions 0 and 1.
  • Store it in another variable.
  • Print the given array.
  • Print the shape of the given array using the shape attribute.
  • Print the result, i.e. the expanded array.
  • Print the shape of the expanded array using the shape attribute.
  • Exit the program.

Below is the implementation:

# Import numpy module using the import keyword
import numpy as np
# Pass the list as an argument to the array() function to create an array.
# Store it in a variable.
gvn_arry = np.array([5, 8, 9])
# Pass the above array and axis=(0,1) as arguments to expand_dims() to insert
# new axes at positions 0 and 1.
# Store it in another variable.
expnd_arry = np.expand_dims(gvn_arry, axis=(0,1))
# Print the given array
print("The given array:\n", gvn_arry)
# Print the shape of the given array using the shape attribute
print("The shape of the given array:", gvn_arry.shape)
# Print the result, i.e. the expanded array
print("The expanded array:\n", expnd_arry)
# Print the shape of the expanded array using the shape attribute
print("The shape of the expanded array:", expnd_arry.shape)

Output:

The given array:
[5 8 9]
The shape of the given array: (3,)
The expanded array:
[[[5 8 9]]]
The shape of the expanded array: (1, 1, 3)
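For this input, a single call with a tuple of axes gives the same result as two single-axis calls applied in ascending order. The sketch also shows that listing the same axis twice is rejected with a ValueError; the exact message text may vary between NumPy versions, so it is not shown.

```python
import numpy as np

a = np.array([5, 8, 9])

# One call with a tuple of axes...
once = np.expand_dims(a, axis=(0, 1))
# ...matches two single-axis calls applied in ascending order
twice = np.expand_dims(np.expand_dims(a, axis=0), axis=1)
print(once.shape, twice.shape)  # (1, 1, 3) (1, 1, 3)

# Listing the same axis twice is rejected
try:
    np.expand_dims(a, axis=(0, 0))
except ValueError as err:
    print("error:", err)
```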