Skip to content

別アカウントの Bedrock を利用するためのクロスアカウント設定を追加 #444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/cdk/cdk.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
"anonymousUsageTracking": true,
"recognizeFileEnabled": false,
"vpcId": null,
"roleArn": "",
"sessionName": "BedrockApiAccess",
"@aws-cdk/aws-lambda:recognizeLayerVersion": true,
"@aws-cdk/core:checkSecretUsage": true,
"@aws-cdk/core:target-partitions": [
Expand Down
58 changes: 55 additions & 3 deletions packages/cdk/lambda/utils/bedrockApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,59 @@ import {
UnrecordedMessage,
} from 'generative-ai-use-cases-jp';
import { BEDROCK_MODELS, BEDROCK_IMAGE_GEN_MODELS } from './models';
import { STSClient, AssumeRoleCommand, } from "@aws-sdk/client-sts";

const client = new BedrockRuntimeClient({
region: process.env.MODEL_REGION,
});
// STSから一時的な認証情報を取得する関数を追加
const assumeRole = async (roleArn: string, sessionName: string) => {
const stsClient = new STSClient({ region: process.env.MODEL_REGION });
const command = new AssumeRoleCommand({
RoleArn: roleArn,
RoleSessionName: sessionName,
});

try {
const response = await stsClient.send(command);
if (response.Credentials) {
return {
accessKeyId: response.Credentials?.AccessKeyId,
secretAccessKey: response.Credentials?.SecretAccessKey,
sessionToken: response.Credentials?.SessionToken,
};
} else {
throw new Error("認証情報を取得できませんでした。");
}
} catch (error) {
console.error("Error assuming role: ", error);
throw error;
}
};

// BedrockRuntimeClientを初期化する関数
const initBedrockClient = async () => {
// ROLE_ARN が設定されているかチェック
if (process.env.ROLE_ARN && process.env.SESSION_NAME) {
// STS から一時的な認証情報を取得してクライアントを初期化
const tempCredentials = await assumeRole(process.env.ROLE_ARN, process.env.SESSION_NAME);

if (!tempCredentials.accessKeyId || !tempCredentials.secretAccessKey || !tempCredentials.sessionToken) {
throw new Error("STSからの認証情報が不完全です。");
}

return new BedrockRuntimeClient({
region: process.env.MODEL_REGION,
credentials: {
accessKeyId: tempCredentials.accessKeyId,
secretAccessKey: tempCredentials.secretAccessKey,
sessionToken: tempCredentials.sessionToken,
}
});
} else {
// STSを使用しない場合のクライアント初期化
return new BedrockRuntimeClient({
region: process.env.MODEL_REGION,
});
}
};

const createBodyText = (
model: string,
Expand Down Expand Up @@ -49,6 +98,7 @@ const extractOutputImage = (

const bedrockApi: ApiInterface = {
invoke: async (model, messages) => {
const client = await initBedrockClient();
const command = new InvokeModelCommand({
modelId: model.modelId,
body: createBodyText(model.modelId, messages),
Expand All @@ -59,6 +109,7 @@ const bedrockApi: ApiInterface = {
return extractOutputText(model.modelId, body);
},
invokeStream: async function* (model, messages) {
const client = await initBedrockClient();
try {
const command = new InvokeModelWithResponseStreamCommand({
modelId: model.modelId,
Expand Down Expand Up @@ -98,6 +149,7 @@ const bedrockApi: ApiInterface = {
}
},
generateImage: async (model, params) => {
const client = await initBedrockClient();
const command = new InvokeModelCommand({
modelId: model.modelId,
body: createBodyImage(model.modelId, params),
Expand Down
54 changes: 45 additions & 9 deletions packages/cdk/lib/construct/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ export class Api extends Construct {
readonly imageGenerationModelIds: string[];
readonly endpointNames: string[];
readonly agentNames: string[];
readonly roleArn: string;
readonly sessionName: string

constructor(scope: Construct, id: string, props: BackendApiProps) {
super(scope, id);
Expand Down Expand Up @@ -93,6 +95,10 @@ export class Api extends Construct {
};
}

// cross account access IAM role
const roleArn = this.node.tryGetContext('roleArn');
const sessionName = this.node.tryGetContext('sessionName');

// Lambda
const predictFunction = new NodejsFunction(this, 'Predict', {
runtime: Runtime.NODEJS_18_X,
Expand All @@ -102,6 +108,8 @@ export class Api extends Construct {
MODEL_REGION: modelRegion,
MODEL_IDS: JSON.stringify(modelIds),
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
ROLE_ARN: roleArn,
SESSION_NAME: sessionName,
},
bundling: {
nodeModules: ['@aws-sdk/client-bedrock-runtime'],
Expand All @@ -118,6 +126,8 @@ export class Api extends Construct {
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
AGENT_REGION: agentRegion,
AGENT_MAP: JSON.stringify(agentMap),
ROLE_ARN: roleArn,
SESSION_NAME: sessionName,
},
bundling: {
nodeModules: [
Expand All @@ -144,6 +154,8 @@ export class Api extends Construct {
MODEL_REGION: modelRegion,
MODEL_IDS: JSON.stringify(modelIds),
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
ROLE_ARN: roleArn,
SESSION_NAME: sessionName,
},
});
table.grantWriteData(predictTitleFunction);
Expand All @@ -156,6 +168,8 @@ export class Api extends Construct {
MODEL_REGION: modelRegion,
MODEL_IDS: JSON.stringify(modelIds),
IMAGE_GENERATION_MODEL_IDS: JSON.stringify(imageGenerationModelIds),
ROLE_ARN: roleArn,
SESSION_NAME: sessionName,
},
bundling: {
nodeModules: ['@aws-sdk/client-bedrock-runtime'],
Expand Down Expand Up @@ -183,15 +197,37 @@ export class Api extends Construct {

// Bedrock は常に権限付与
// Bedrock Policy
const bedrockPolicy = new PolicyStatement({
effect: Effect.ALLOW,
resources: ['*'],
actions: ['bedrock:*', 'logs:*'],
});
predictStreamFunction.role?.addToPrincipalPolicy(bedrockPolicy);
predictFunction.role?.addToPrincipalPolicy(bedrockPolicy);
predictTitleFunction.role?.addToPrincipalPolicy(bedrockPolicy);
generateImageFunction.role?.addToPrincipalPolicy(bedrockPolicy);
if (typeof roleArn !== 'string' || roleArn === '') {
const bedrockPolicy = new PolicyStatement({
effect: Effect.ALLOW,
resources: ['*'],
actions: ['bedrock:*', 'logs:*'],
});
predictStreamFunction.role?.addToPrincipalPolicy(bedrockPolicy);
predictFunction.role?.addToPrincipalPolicy(bedrockPolicy);
predictTitleFunction.role?.addToPrincipalPolicy(bedrockPolicy);
generateImageFunction.role?.addToPrincipalPolicy(bedrockPolicy);
} else {
// roleArn が指定されている場合のポリシー
const logsPolicy = new PolicyStatement({
effect: Effect.ALLOW,
actions: ['logs:*'],
resources: ['*'],
});
const assumeRolePolicy = new PolicyStatement({
effect: Effect.ALLOW,
actions: ['sts:AssumeRole'],
resources: [roleArn],
});
predictStreamFunction.role?.addToPrincipalPolicy(logsPolicy);
predictFunction.role?.addToPrincipalPolicy(logsPolicy);
predictTitleFunction.role?.addToPrincipalPolicy(logsPolicy);
generateImageFunction.role?.addToPrincipalPolicy(logsPolicy);
predictStreamFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
predictFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
predictTitleFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
generateImageFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
}

const createChatFunction = new NodejsFunction(this, 'CreateChat', {
runtime: Runtime.NODEJS_18_X,
Expand Down