Skip to content

Commit 80b46d6

Browse files
feat(tpu): add tpu vm list sample. (#9606)
Implemented tpu_vm_list sample, created test
1 parent bb0108c commit 80b46d6

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

tpu/src/main/java/tpu/ListTpuVms.java

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_list]
20+
import com.google.cloud.tpu.v2.ListNodesRequest;
21+
import com.google.cloud.tpu.v2.TpuClient;
22+
import java.io.IOException;
23+
24+
public class ListTpuVms {
25+
26+
public static void main(String[] args) throws IOException {
27+
// TODO(developer): Replace these variables before running the sample.
28+
// Project ID or project number of the Google Cloud project you want to use.
29+
String projectId = "YOUR_PROJECT_ID";
30+
// The zone where the TPUs are located.
31+
// For more information about supported TPU types for specific zones,
32+
// see https://cloud.google.com/tpu/docs/regions-zones
33+
String zone = "us-central1-f";
34+
35+
listTpuVms(projectId, zone);
36+
}
37+
38+
// Lists TPU VMs in the specified zone.
39+
public static TpuClient.ListNodesPage listTpuVms(String projectId, String zone)
40+
throws IOException {
41+
// Initialize client that will be used to send requests. This client only needs to be created
42+
// once, and can be reused for multiple requests.
43+
try (TpuClient tpuClient = TpuClient.create()) {
44+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
45+
46+
ListNodesRequest request = ListNodesRequest.newBuilder().setParent(parent).build();
47+
48+
return tpuClient.listNodes(request).getPage();
49+
}
50+
}
51+
}
52+
//[END tpu_vm_list]

tpu/src/test/java/tpu/TpuVmIT.java

+26
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
import com.google.cloud.tpu.v2.CreateNodeRequest;
3131
import com.google.cloud.tpu.v2.DeleteNodeRequest;
3232
import com.google.cloud.tpu.v2.GetNodeRequest;
33+
import com.google.cloud.tpu.v2.ListNodesRequest;
3334
import com.google.cloud.tpu.v2.Node;
3435
import com.google.cloud.tpu.v2.TpuClient;
3536
import com.google.cloud.tpu.v2.TpuSettings;
3637
import java.io.ByteArrayOutputStream;
3738
import java.io.IOException;
3839
import java.io.PrintStream;
40+
import java.util.Arrays;
41+
import java.util.List;
3942
import java.util.concurrent.ExecutionException;
4043
import org.junit.jupiter.api.Test;
4144
import org.junit.jupiter.api.Timeout;
@@ -140,4 +143,27 @@ public void testCreateTpuVmWithTopologyFlag()
140143
assertEquals(returnedNode, mockNode);
141144
}
142145
}
146+
147+
@Test
148+
public void testListTpuVm() throws IOException {
149+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
150+
Node mockNode1 = mock(Node.class);
151+
Node mockNode2 = mock(Node.class);
152+
List<Node> mockListNodes = Arrays.asList(mockNode1, mockNode2);
153+
TpuClient mockTpuClient = mock(TpuClient.class);
154+
TpuClient.ListNodesPagedResponse mockListNodesResponse =
155+
mock(TpuClient.ListNodesPagedResponse.class);
156+
TpuClient.ListNodesPage mockListNodesPage = mock(TpuClient.ListNodesPage.class);
157+
158+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
159+
when(mockTpuClient.listNodes(any(ListNodesRequest.class))).thenReturn(mockListNodesResponse);
160+
when(mockListNodesResponse.getPage()).thenReturn(mockListNodesPage);
161+
when(mockListNodesPage.getValues()).thenReturn(mockListNodes);
162+
163+
TpuClient.ListNodesPage returnedListNodes = ListTpuVms.listTpuVms(PROJECT_ID, ZONE);
164+
165+
assertThat(returnedListNodes.getValues()).isEqualTo(mockListNodes);
166+
verify(mockTpuClient, times(1)).listNodes(any(ListNodesRequest.class));
167+
}
168+
}
143169
}

0 commit comments

Comments
 (0)