
Support for AWS Transcribe Medical

Open tmarice opened this issue 4 years ago • 4 comments

Hello,

Since AWS released the Medical version of the Transcribe service, it would be great if this SDK natively supported that option too. The APIs are very similar, so we managed to hack together an ugly version of TranscribeMedicalStreamingClient by inheriting from TranscribeStreamingClient and applying similar hacks to TranscribeMedicalStreamingRequestSerializer and StartMedicalStreamTranscriptionRequest:

from io import BufferedIOBase
from typing import Dict, Optional, Tuple

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.handlers import TranscriptResultStreamHandler
from amazon_transcribe.httpsession import AwsCrtHttpSessionManager
from amazon_transcribe.model import StartStreamTranscriptionEventStream, StartStreamTranscriptionRequest
from amazon_transcribe.serialize import HEADER_VALUE, Serializer, TranscribeStreamingRequestSerializer
from amazon_transcribe.signer import SigV4RequestSigner
from amazon_transcribe.structures import BufferableByteStream
from amazon_transcribe.utils import _add_required_headers

##


class StartMedicalStreamTranscriptionRequest(StartStreamTranscriptionRequest):
    def __init__(self, *args, **kwargs):
        audio_type = kwargs.pop("audio_type")
        specialty = kwargs.pop("specialty")

        super().__init__(*args, **kwargs)

        self.audio_type = audio_type
        self.specialty = specialty


##


class TranscribeMedicalStreamingRequestSerializer(TranscribeStreamingRequestSerializer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.request_uri = "/medical-stream-transcription"

    def serialize(self) -> Tuple[Dict[str, HEADER_VALUE], BufferedIOBase]:
        headers = {
            "x-amzn-transcribe-language-code": self.request_shape.language_code,
            "x-amzn-transcribe-sample-rate": self.request_shape.media_sample_rate_hz,
            "x-amzn-transcribe-media-encoding": self.request_shape.media_encoding,
            "x-amzn-transcribe-vocabulary-name": self.request_shape.vocabulary_name,
            "x-amzn-transcribe-session-id": self.request_shape.session_id,
            "x-amzn-transcribe-vocabulary-filter-method": self.request_shape.vocab_filter_method,
            "x-amzn-transcribe-vocabulary-filter-name": self.request_shape.vocab_filter_name,
            "x-amzn-transcribe-show-speaker-label": self.request_shape.show_speaker_label,
            "x-amzn-transcribe-enable-channel-identification": self.request_shape.enable_channel_identification,
            "x-amzn-transcribe-number-of-channels": self.request_shape.number_of_channels,
            "x-amzn-transcribe-specialty": self.request_shape.specialty,
            "x-amzn-transcribe-type": self.request_shape.audio_type,
        }

        _add_required_headers(self.endpoint, headers)

        body = BufferableByteStream()
        return headers, body


##


class TranscribeMedicalStreamingClient(TranscribeStreamingClient):
    async def start_stream_transcription(
        self,
        *,
        language_code: str,
        media_sample_rate_hz: int,
        media_encoding: str,
        audio_type: str,
        specialty: str,
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
    ) -> StartStreamTranscriptionEventStream:
        transcribe_streaming_request = StartMedicalStreamTranscriptionRequest(
            language_code,
            media_sample_rate_hz,
            media_encoding,
            vocabulary_name,
            session_id,
            vocab_filter_method,
            vocab_filter_name,
            show_speaker_label,
            enable_channel_identification,
            number_of_channels,
            audio_type=audio_type,
            specialty=specialty,
        )
        endpoint = await self._endpoint_resolver.resolve(self.region)
        self._serializer: Serializer = TranscribeMedicalStreamingRequestSerializer(
            endpoint=endpoint,
            transcribe_request=transcribe_streaming_request,
        )
        request = self._serializer.serialize_to_request()

        creds = await self._credential_resolver.get_credentials()
        signer = SigV4RequestSigner("transcribe", self.region)
        signed_request = signer.sign(request, creds)

        session = AwsCrtHttpSessionManager(self._eventloop)

        response = await session.make_request(
            signed_request.uri,
            method=signed_request.method,
            headers=signed_request.headers.as_list(),
            body=signed_request.body,
        )
        resolved_response = await response.resolve_response()

        status_code = resolved_response.status_code
        if status_code >= 400:
            # We need to close before we can consume the body or this will hang
            signed_request.body.close()
            body_bytes = await response.consume_body()
            raise self._response_parser.parse_exception(resolved_response, body_bytes)
        elif status_code != 200:
            raise RuntimeError("Unexpected status code encountered: %s" % status_code)

        parsed_response = self._response_parser.parse_start_stream_transcription_response(
            resolved_response,
            response,
        )

        # The audio stream is returned as output because it requires
        # the signature from the initial HTTP request to be useable
        audio_stream = self._create_audio_stream(signed_request)
        return StartStreamTranscriptionEventStream(audio_stream, parsed_response)
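
In case it helps anyone, here is roughly how we drive this client end to end (a minimal sketch, assuming the same input_stream/output_stream interface as the upstream TranscribeStreamingClient; the MedicalTranscriptHandler and transcribe_medical names are ours, and the specialty and audio type values are just examples of what Transcribe Medical accepts):

import asyncio

from amazon_transcribe.handlers import TranscriptResultStreamHandler
from amazon_transcribe.model import TranscriptEvent


class MedicalTranscriptHandler(TranscriptResultStreamHandler):
    async def handle_transcript_event(self, transcript_event: TranscriptEvent):
        # Print every partial/final transcript alternative as it arrives.
        for result in transcript_event.transcript.results:
            for alternative in result.alternatives:
                print(alternative.transcript)


async def transcribe_medical(audio_path: str):
    client = TranscribeMedicalStreamingClient(region="us-east-1")

    stream = await client.start_stream_transcription(
        language_code="en-US",
        media_sample_rate_hz=16000,
        media_encoding="pcm",
        specialty="PRIMARYCARE",    # e.g. PRIMARYCARE, CARDIOLOGY, ...
        audio_type="CONVERSATION",  # or DICTATION
    )

    async def write_chunks():
        # Feed raw 16 kHz PCM into the request body in small chunks.
        with open(audio_path, "rb") as f:
            chunk = f.read(1024 * 16)
            while chunk:
                await stream.input_stream.send_audio_event(audio_chunk=chunk)
                chunk = f.read(1024 * 16)
        await stream.input_stream.end_stream()

    handler = MedicalTranscriptHandler(stream.output_stream)
    await asyncio.gather(write_chunks(), handler.handle_events())


# asyncio.run(transcribe_medical("sample.raw"))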

tmarice · Jan 27 '21 15:01

Thanks @tmarice! This is awesome. Your code needed some modifications to work with the latest SDK (0.4.0).

from typing import Optional

from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.httpsession import AwsCrtHttpSessionManager
from amazon_transcribe.model import StartStreamTranscriptionEventStream, StartStreamTranscriptionRequest
from amazon_transcribe.serialize import TranscribeStreamingSerializer
from amazon_transcribe.signer import SigV4RequestSigner
from amazon_transcribe.request import Request

##


class StartMedicalStreamTranscriptionRequest(StartStreamTranscriptionRequest):
    def __init__(self, *args, **kwargs):
        audio_type = kwargs.pop("audio_type")
        specialty = kwargs.pop("specialty")

        super().__init__(*args, **kwargs)

        self.audio_type = audio_type
        self.specialty = specialty


##


class TranscribeMedicalStreamingSerializer(TranscribeStreamingSerializer):
    def __init__(self):
        super().__init__()

        self.request_uri = "/medical-stream-transcription"

    def serialize_start_stream_transcription_request(
        self, endpoint: str, request_shape: StartStreamTranscriptionRequest
    ) -> Request:
        request = super().serialize_start_stream_transcription_request(endpoint, request_shape)
        request.path = self.request_uri

        request.headers.update(
            super()._serialize_str_header("specialty", request_shape.specialty)
        )
        request.headers.update(
            super()._serialize_str_header("type", request_shape.audio_type)
        )

        return request

##


class TranscribeMedicalStreamingClient(TranscribeStreamingClient):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._serializer = TranscribeMedicalStreamingSerializer()

    async def start_stream_transcription(
        self,
        *,
        language_code: str,
        media_sample_rate_hz: int,
        media_encoding: str,
        audio_type: str,
        specialty: str,
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
    ) -> StartStreamTranscriptionEventStream:
        transcribe_streaming_request = StartMedicalStreamTranscriptionRequest(
            language_code,
            media_sample_rate_hz,
            media_encoding,
            vocabulary_name,
            session_id,
            vocab_filter_method,
            vocab_filter_name,
            show_speaker_label,
            enable_channel_identification,
            number_of_channels,
            audio_type=audio_type,
            specialty=specialty,
        )
        endpoint = await self._endpoint_resolver.resolve(self.region)

        # Same request flow as the base client, but using our Medical-aware serializer
        request = self._serializer.serialize_start_stream_transcription_request(
            endpoint=endpoint, request_shape=transcribe_streaming_request,
        ).prepare()

        creds = await self._credential_resolver.get_credentials()
        signer = SigV4RequestSigner("transcribe", self.region)
        signed_request = signer.sign(request, creds)

        session = AwsCrtHttpSessionManager(self._eventloop)

        response = await session.make_request(
            signed_request.uri,
            method=signed_request.method,
            headers=signed_request.headers.as_list(),
            body=signed_request.body,
        )
        resolved_response = await response.resolve_response()

        status_code = resolved_response.status_code
        if status_code >= 400:
            # We need to close before we can consume the body or this will hang
            signed_request.body.close()
            body_bytes = await response.consume_body()
            raise self._response_parser.parse_exception(resolved_response, body_bytes)
        elif status_code != 200:
            raise RuntimeError("Unexpected status code encountered: %s" % status_code)

        parsed_response = self._response_parser.parse_start_stream_transcription_response(
            resolved_response,
            response,
        )

        # The audio stream is returned as output because it requires
        # the signature from the initial HTTP request to be useable
        audio_stream = self._create_audio_stream(signed_request)
        return StartStreamTranscriptionEventStream(audio_stream, parsed_response)
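
For anyone adapting this, here is roughly how we sanity-checked the subclassed serializer before wiring it into the client (a quick sketch, assuming the classes above are in the same module; the endpoint URL is just an illustrative value):

serializer = TranscribeMedicalStreamingSerializer()

# Mirror the positional arguments the client passes, with the optionals left as None.
request_shape = StartMedicalStreamTranscriptionRequest(
    "en-US",  # language_code
    16000,    # media_sample_rate_hz
    "pcm",    # media_encoding
    None,     # vocabulary_name
    None,     # session_id
    None,     # vocab_filter_method
    None,     # vocab_filter_name
    None,     # show_speaker_label
    None,     # enable_channel_identification
    None,     # number_of_channels
    audio_type="CONVERSATION",
    specialty="PRIMARYCARE",
)

request = serializer.serialize_start_stream_transcription_request(
    endpoint="https://transcribestreaming.us-east-1.amazonaws.com",  # illustrative endpoint
    request_shape=request_shape,
)

print(request.path)     # expect /medical-stream-transcription
print(request.headers)  # expect the specialty/type headers alongside the usual ones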

mikeballou-augmedix · May 18 '21 22:05

Hey guys, any updates on this?

david-oliveira-br · Feb 22 '22 12:02

Any updates? Surprised this hasn't been built already.

vikramsubramanian · May 31 '23 00:05

Any updates? Also I'd like support for show_speaker_labels.

alexe0336 · May 30 '24 20:05