Skip to content

Commit 625b9c2

Browse files
committed
別アカウントの Bedrock を利用するためのクロスアカウント設定を追加
1 parent 2275212 commit 625b9c2

File tree

3 files changed

+102
-12
lines changed

3 files changed

+102
-12
lines changed

packages/cdk/cdk.json

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
"anonymousUsageTracking": true,
4949
"recognizeFileEnabled": false,
5050
"vpcId": null,
51+
"roleArn": "",
52+
"sessionName": "BedrockApiAccess",
5153
"@aws-cdk/aws-lambda:recognizeLayerVersion": true,
5254
"@aws-cdk/core:checkSecretUsage": true,
5355
"@aws-cdk/core:target-partitions": [

packages/cdk/lambda/utils/bedrockApi.ts

+55-3
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,59 @@ import {
1313
UnrecordedMessage,
1414
} from 'generative-ai-use-cases-jp';
1515
import { BEDROCK_MODELS, BEDROCK_IMAGE_GEN_MODELS } from './models';
16+
import { STSClient, AssumeRoleCommand, } from "@aws-sdk/client-sts";
1617

17-
const client = new BedrockRuntimeClient({
18-
region: process.env.MODEL_REGION,
19-
});
18+
// STSから一時的な認証情報を取得する関数を追加
19+
const assumeRole = async (roleArn: string, sessionName: string) => {
20+
const stsClient = new STSClient({ region: process.env.MODEL_REGION });
21+
const command = new AssumeRoleCommand({
22+
RoleArn: roleArn,
23+
RoleSessionName: sessionName,
24+
});
25+
26+
try {
27+
const response = await stsClient.send(command);
28+
if (response.Credentials) {
29+
return {
30+
accessKeyId: response.Credentials?.AccessKeyId,
31+
secretAccessKey: response.Credentials?.SecretAccessKey,
32+
sessionToken: response.Credentials?.SessionToken,
33+
};
34+
} else {
35+
throw new Error("認証情報を取得できませんでした。");
36+
}
37+
} catch (error) {
38+
console.error("Error assuming role: ", error);
39+
throw error;
40+
}
41+
};
42+
43+
// BedrockRuntimeClientを初期化する関数
44+
const initBedrockClient = async () => {
45+
// ROLE_ARN が設定されているかチェック
46+
if (process.env.ROLE_ARN && process.env.SESSION_NAME) {
47+
// STS から一時的な認証情報を取得してクライアントを初期化
48+
const tempCredentials = await assumeRole(process.env.ROLE_ARN, process.env.SESSION_NAME);
49+
50+
if (!tempCredentials.accessKeyId || !tempCredentials.secretAccessKey || !tempCredentials.sessionToken) {
51+
throw new Error("STSからの認証情報が不完全です。");
52+
}
53+
54+
return new BedrockRuntimeClient({
55+
region: process.env.MODEL_REGION,
56+
credentials: {
57+
accessKeyId: tempCredentials.accessKeyId,
58+
secretAccessKey: tempCredentials.secretAccessKey,
59+
sessionToken: tempCredentials.sessionToken,
60+
}
61+
});
62+
} else {
63+
// STSを使用しない場合のクライアント初期化
64+
return new BedrockRuntimeClient({
65+
region: process.env.MODEL_REGION,
66+
});
67+
}
68+
};
2069

2170
const createBodyText = (
2271
model: string,
@@ -49,6 +98,7 @@ const extractOutputImage = (
4998

5099
const bedrockApi: ApiInterface = {
51100
invoke: async (model, messages) => {
101+
const client = await initBedrockClient();
52102
const command = new InvokeModelCommand({
53103
modelId: model.modelId,
54104
body: createBodyText(model.modelId, messages),
@@ -59,6 +109,7 @@ const bedrockApi: ApiInterface = {
59109
return extractOutputText(model.modelId, body);
60110
},
61111
invokeStream: async function* (model, messages) {
112+
const client = await initBedrockClient();
62113
try {
63114
const command = new InvokeModelWithResponseStreamCommand({
64115
modelId: model.modelId,
@@ -98,6 +149,7 @@ const bedrockApi: ApiInterface = {
98149
}
99150
},
100151
generateImage: async (model, params) => {
152+
const client = await initBedrockClient();
101153
const command = new InvokeModelCommand({
102154
modelId: model.modelId,
103155
body: createBodyImage(model.modelId, params),

packages/cdk/lib/construct/api.ts

+45-9
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ export class Api extends Construct {
3131
readonly imageGenerationModelIds: string[];
3232
readonly endpointNames: string[];
3333
readonly agentNames: string[];
34+
readonly roleArn: string;
35+
readonly sessionName: string
3436

3537
constructor(scope: Construct, id: string, props: BackendApiProps) {
3638
super(scope, id);
@@ -93,6 +95,10 @@ export class Api extends Construct {
9395
};
9496
}
9597

98+
// cross account access IAM role
99+
const roleArn = this.node.tryGetContext('roleArn');
100+
const sessionName = this.node.tryGetContext('sessionName');
101+
96102
// Lambda
97103
const predictFunction = new NodejsFunction(this, 'Predict', {
98104
runtime: Runtime.NODEJS_18_X,
@@ -102,6 +108,8 @@ export class Api extends Construct {
102108
MODEL_REGION: modelRegion,
103109
MODEL_IDS: JSON.stringify(modelIds),
104110
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
111+
ROLE_ARN: roleArn,
112+
SESSION_NAME: sessionName,
105113
},
106114
bundling: {
107115
nodeModules: ['@aws-sdk/client-bedrock-runtime'],
@@ -118,6 +126,8 @@ export class Api extends Construct {
118126
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
119127
AGENT_REGION: agentRegion,
120128
AGENT_MAP: JSON.stringify(agentMap),
129+
ROLE_ARN: roleArn,
130+
SESSION_NAME: sessionName,
121131
},
122132
bundling: {
123133
nodeModules: [
@@ -144,6 +154,8 @@ export class Api extends Construct {
144154
MODEL_REGION: modelRegion,
145155
MODEL_IDS: JSON.stringify(modelIds),
146156
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
157+
ROLE_ARN: roleArn,
158+
SESSION_NAME: sessionName,
147159
},
148160
});
149161
table.grantWriteData(predictTitleFunction);
@@ -156,6 +168,8 @@ export class Api extends Construct {
156168
MODEL_REGION: modelRegion,
157169
MODEL_IDS: JSON.stringify(modelIds),
158170
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
171+
ROLE_ARN: roleArn,
172+
SESSION_NAME: sessionName,
159173
},
160174
bundling: {
161175
nodeModules: ['@aws-sdk/client-bedrock-runtime'],
@@ -183,15 +197,37 @@ export class Api extends Construct {
183197

184198
// Bedrock は常に権限付与
185199
// Bedrock Policy
186-
const bedrockPolicy = new PolicyStatement({
187-
effect: Effect.ALLOW,
188-
resources: ['*'],
189-
actions: ['bedrock:*', 'logs:*'],
190-
});
191-
predictStreamFunction.role?.addToPrincipalPolicy(bedrockPolicy);
192-
predictFunction.role?.addToPrincipalPolicy(bedrockPolicy);
193-
predictTitleFunction.role?.addToPrincipalPolicy(bedrockPolicy);
194-
generateImageFunction.role?.addToPrincipalPolicy(bedrockPolicy);
200+
if (typeof roleArn !== 'string' || roleArn === '') {
201+
const bedrockPolicy = new PolicyStatement({
202+
effect: Effect.ALLOW,
203+
resources: ['*'],
204+
actions: ['bedrock:*', 'logs:*'],
205+
});
206+
predictStreamFunction.role?.addToPrincipalPolicy(bedrockPolicy);
207+
predictFunction.role?.addToPrincipalPolicy(bedrockPolicy);
208+
predictTitleFunction.role?.addToPrincipalPolicy(bedrockPolicy);
209+
generateImageFunction.role?.addToPrincipalPolicy(bedrockPolicy);
210+
} else {
211+
// roleArn が指定されている場合のポリシー
212+
const logsPolicy = new PolicyStatement({
213+
effect: Effect.ALLOW,
214+
actions: ['logs:*'],
215+
resources: ['*'],
216+
});
217+
const assumeRolePolicy = new PolicyStatement({
218+
effect: Effect.ALLOW,
219+
actions: ['sts:AssumeRole'],
220+
resources: [roleArn],
221+
});
222+
predictStreamFunction.role?.addToPrincipalPolicy(logsPolicy);
223+
predictFunction.role?.addToPrincipalPolicy(logsPolicy);
224+
predictTitleFunction.role?.addToPrincipalPolicy(logsPolicy);
225+
generateImageFunction.role?.addToPrincipalPolicy(logsPolicy);
226+
predictStreamFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
227+
predictFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
228+
predictTitleFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
229+
generateImageFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
230+
}
195231

196232
const createChatFunction = new NodejsFunction(this, 'CreateChat', {
197233
runtime: Runtime.NODEJS_18_X,

0 commit comments

Comments
 (0)